Commit 51ab86ff authored by Baber
Browse files

test

parent 5b8a7506
import abc
import ast
import collections
import logging
import random
import re
from collections import defaultdict
from collections.abc import Callable
from copy import deepcopy
from dataclasses import asdict, dataclass
from inspect import getsource
from itertools import groupby
from operator import attrgetter
from typing import (
Any,
Dict,
......@@ -1778,7 +1780,7 @@ class ConfigurableTask(Task):
rank=1,
limit=None,
world_size=1,
) -> Optional[dict[str, list[dict]]]:
) -> list[MetricResult]:
"""Calculate metrics for all datapoints in the task.
Args:
......@@ -1794,7 +1796,6 @@ class ConfigurableTask(Task):
"""
if not self._instances:
return
from collections import defaultdict
if filter_keys is None:
filter_keys = (
......@@ -1808,7 +1809,8 @@ class ConfigurableTask(Task):
instances_by_doc_id = defaultdict(list)
for instance in self.instances:
instances_by_doc_id[instance.doc_id].append(instance)
all_metrics = collections.defaultdict(list)
# all_metrics = collections.defaultdict(list)
all_metrics = []
for filter_key in filter_keys:
doc_iterator = self.doc_iterator(
rank=rank,
......@@ -1835,12 +1837,25 @@ class ConfigurableTask(Task):
else [req.filtered_resps[filter_key]]
)
]
all_metrics[filter_key].append(
all_metrics.append(
MetricResult(scores=metrics, doc_id=doc_id, filter_key=filter_key)
)
return all_metrics
@staticmethod
def compute_agg_metrics(self, metric_results: list[MetricResult]):
    """Group per-document metric scores by ``(filter_key, metric_name)``.

    The results are ordered by their ``filter_key`` and ``metric_name``
    attributes and then bucketed on that same pair. For every bucket, the
    value stored under ``metric_name`` is read from each entry of each
    result's ``scores`` and the values are transposed, so that the i-th
    inner list holds the i-th score of every result in the bucket.

    Args:
        self: Unused by the body.
            NOTE(review): the function is declared ``@staticmethod`` yet
            still takes an explicit ``self``; callers therefore have to
            pass a placeholder as the first argument. Confirm the intent.
        metric_results: Results to aggregate. Assumes every item exposes
            ``filter_key``, ``metric_name`` and an iterable ``scores`` of
            mappings keyed by metric name -- TODO confirm against
            ``MetricResult``.

    Returns:
        A dict mapping each ``(filter_key, metric_name)`` tuple to a list
        of lists of score values. Transposition uses ``zip``, so each
        bucket is truncated to the shortest ``scores`` length within it.
    """
    bucket_key = attrgetter("filter_key", "metric_name")

    # groupby only merges adjacent items, so order by the same key first.
    ordered = sorted(metric_results, key=bucket_key)

    aggregated = {}
    for pair, members in groupby(ordered, key=bucket_key):
        # One lazy stream of score values per result in this bucket.
        streams = [
            (entry[res.metric_name] for entry in res.scores) for res in members
        ]
        # Transpose: column i gathers the i-th value from every stream.
        aggregated[pair] = [list(column) for column in zip(*streams)]
    return aggregated
class MultipleChoiceTask(Task):
OUTPUT_TYPE = "loglikelihood"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment