Commit e30978c7 authored by Baber

add metric calculation method to configurable task

parent e9eb451e
@@ -1757,6 +1757,46 @@ class ConfigurableTask(Task):
f"num_samples={len(self.eval_docs)})"
)
def calculate_metrics(
self, instances_by_doc_id, filter_key, samples, rank, limit, world_size
):
"""Calculate metrics for all datapoints in the task.
Args:
instances_by_doc_id (dict): Dictionary mapping doc_ids to lists of instances.
filter_key (str): The filter key to use for filtered responses.
samples (dict, optional): Dictionary of sample indices to evaluate.
rank (int): The process rank.
limit (int, optional): Limit on number of examples to evaluate.
world_size (int): Total number of processes.
Returns:
list: A list of metrics calculated for each document.
"""
all_metrics = []
# indices = samples.get(self.config.task, None) if samples is not None else None
doc_iterator = self.doc_iterator(
rank=rank,
limit=limit,
world_size=world_size,
# samples=indices,
)
for doc_id, doc in doc_iterator:
# doc_id_true = indices[doc_id] if indices else doc_id
requests = instances_by_doc_id[doc_id]
metrics = [
self.process_results(doc, response)
for req in requests
for response in req.filtered_resps[filter_key]
]
all_metrics.extend(metrics)
return all_metrics
class MultipleChoiceTask(Task):
OUTPUT_TYPE = "loglikelihood"
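For context, a minimal sketch of how the new ConfigurableTask.calculate_metrics method is meant to be driven. The grouping of instances by doc_id, the "none" filter name, and the single-process rank/world_size values are assumptions that mirror how evaluate() (changed below) prepares its inputs; `task` is assumed to be a ConfigurableTask whose requests have already been run and filtered.

from collections import defaultdict

# Hypothetical driver for the new method (not part of this commit).
instances_by_doc_id = defaultdict(list)
for instance in task.instances:
    instances_by_doc_id[instance.doc_id].append(instance)
for instances in instances_by_doc_id.values():
    instances.sort(key=lambda x: x.idx)

# Mirrors the call added to evaluate() below; single-process values assumed.
metrics = task.calculate_metrics(
    instances_by_doc_id=instances_by_doc_id,
    filter_key="none",  # assumed default filter name
    samples=None,
    rank=0,
    limit=None,
    world_size=1,
)
# `metrics` is a flat list of per-response metric dicts, e.g. [{"acc": 1.0}, ...]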
@@ -596,56 +596,125 @@ def evaluate(
instances.sort(key=lambda x: x.idx)
# iterate over different filters used
for filter_key in task.instances[0].filtered_resps.keys():
indices = (
samples.get(task_output.task_name, None)
if samples is not None
else None
)
doc_iterator = task.doc_iterator(
rank=RANK,
limit=limit,
world_size=WORLD_SIZE,
samples=indices,
)
for doc_id, doc in doc_iterator:
if indices:
doc_id_true = indices[doc_id]
else:
doc_id_true = doc_id
requests = instances_by_doc_id[doc_id]
metrics: list[dict] = [
task.process_results(doc, response)
for req in requests
for response in req.filtered_resps[filter_key]
]
if hasattr(task, "calculate_metrics"):
# Use the new method if it exists (ConfigurableTask)
metrics = task.calculate_metrics(
instances_by_doc_id=instances_by_doc_id,
filter_key=filter_key,
samples=samples,
rank=RANK,
limit=limit,
world_size=WORLD_SIZE,
)
# Add sample logging here too - similar to what's done in the else branch
if log_samples:
target = task.doc_to_target(doc)
example = {
"doc_id": doc_id_true,
"doc": doc,
"target": target,
"arguments": [req.args for req in requests],
"resps": [req.resps for req in requests],
"filtered_resps": [
req.filtered_resps[filter_key] for req in requests
],
"filter": filter_key,
"metrics": list({k for m in metrics for k in m.keys()}),
"doc_hash": hash_string(
json.dumps(
requests[0].doc,
indent=2,
default=handle_non_serializable,
ensure_ascii=False,
)
),
"prompt_hash": hash_string(requests[0].arguments[0]),
"target_hash": hash_string(str(target)),
}
example.update(metrics)
task_output.logged_samples.append(example)
for metric, value in metrics.items():
task_output.sample_metrics[(metric, filter_key)].append(value)
indices = (
samples.get(task_output.task_name, None)
if samples is not None
else None
)
doc_iterator = task.doc_iterator(
rank=RANK,
limit=limit,
world_size=WORLD_SIZE,
samples=indices,
)
for doc_id, doc in doc_iterator:
doc_id_true = indices[doc_id] if indices else doc_id
requests = instances_by_doc_id[doc_id]
if requests: # Make sure there are requests for this doc_id
# Get the metrics for this document
doc_metrics = [
task.process_results(doc, response)
for req in requests
for response in req.filtered_resps[filter_key]
]
target = task.doc_to_target(doc)
example = {
"doc_id": doc_id_true,
"doc": doc,
"target": target,
"arguments": [req.args for req in requests],
"resps": [req.resps for req in requests],
"filtered_resps": [
req.filtered_resps[filter_key] for req in requests
],
"filter": filter_key,
"metrics": doc_metrics,
"doc_hash": hash_string(
json.dumps(
requests[0].doc,
indent=2,
default=handle_non_serializable,
ensure_ascii=False,
)
),
"prompt_hash": hash_string(requests[0].arguments[0]),
"target_hash": hash_string(str(target)),
}
task_output.logged_samples.append(example)
# Process all metrics returned from calculate_metrics
for x in metrics:
for metric, value in x.items():
task_output.sample_metrics[(metric, filter_key)].append(value)
else:
# Fall back to the original approach for non-ConfigurableTask instances
indices = (
samples.get(task_output.task_name, None)
if samples is not None
else None
)
doc_iterator = task.doc_iterator(
rank=RANK,
limit=limit,
world_size=WORLD_SIZE,
samples=indices,
)
for doc_id, doc in doc_iterator:
if indices:
doc_id_true = indices[doc_id]
else:
doc_id_true = doc_id
requests = instances_by_doc_id[doc_id]
metrics: list[dict] = [
task.process_results(doc, response)
for req in requests
for response in req.filtered_resps[filter_key]
]
if log_samples:
target = task.doc_to_target(doc)
example = {
"doc_id": doc_id_true,
"doc": doc,
"target": target,
"arguments": [req.args for req in requests],
"resps": [req.resps for req in requests],
"filtered_resps": [
req.filtered_resps[filter_key] for req in requests
],
"filter": filter_key,
"metrics": metrics,
"doc_hash": hash_string(
json.dumps(
requests[0].doc,
indent=2,
default=handle_non_serializable,
ensure_ascii=False,
)
),
"prompt_hash": hash_string(requests[0].arguments[0]),
"target_hash": hash_string(str(target)),
}
example.update({"metrics": metrics})
task_output.logged_samples.append(example)
for x in metrics:
for metric, value in x.items():
task_output.sample_metrics[(metric, filter_key)].append(
value
)
if WORLD_SIZE > 1:
# if multigpu, then gather data across all ranks to rank 0
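To make the aggregation step concrete: the list returned by calculate_metrics is folded into task_output.sample_metrics keyed by (metric, filter_key) and reduced later. Below is a rough, self-contained sketch; the mean reduction and the helper name aggregate_sample_metrics are assumptions for illustration (the harness applies each metric's configured aggregation, not necessarily a mean).

from collections import defaultdict
from statistics import mean

def aggregate_sample_metrics(metrics, filter_key):
    # `metrics` is the flat list of per-response dicts returned by calculate_metrics().
    sample_metrics = defaultdict(list)  # stand-in for task_output.sample_metrics
    for per_response in metrics:
        for metric, value in per_response.items():
            sample_metrics[(metric, filter_key)].append(value)
    # Mean is an assumption for the sketch; real tasks use their configured aggregation.
    return {key: mean(values) for key, values in sample_metrics.items()}

# Example: two docs, one response each
print(aggregate_sample_metrics([{"acc": 1.0}, {"acc": 0.0}], "none"))
# -> {('acc', 'none'): 0.5}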