Commit 51ab86ff authored by Baber
Browse files

test

parent 5b8a7506
import abc import abc
import ast import ast
import collections
import logging import logging
import random import random
import re import re
from collections import defaultdict
from collections.abc import Callable from collections.abc import Callable
from copy import deepcopy from copy import deepcopy
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from inspect import getsource from inspect import getsource
from itertools import groupby
from operator import attrgetter
from typing import ( from typing import (
Any, Any,
Dict, Dict,
...@@ -1778,7 +1780,7 @@ class ConfigurableTask(Task): ...@@ -1778,7 +1780,7 @@ class ConfigurableTask(Task):
rank=1, rank=1,
limit=None, limit=None,
world_size=1, world_size=1,
) -> Optional[dict[str, list[dict]]]: ) -> list[MetricResult]:
"""Calculate metrics for all datapoints in the task. """Calculate metrics for all datapoints in the task.
Args: Args:
...@@ -1794,7 +1796,6 @@ class ConfigurableTask(Task): ...@@ -1794,7 +1796,6 @@ class ConfigurableTask(Task):
""" """
if not self._instances: if not self._instances:
return return
from collections import defaultdict
if filter_keys is None: if filter_keys is None:
filter_keys = ( filter_keys = (
...@@ -1808,7 +1809,8 @@ class ConfigurableTask(Task): ...@@ -1808,7 +1809,8 @@ class ConfigurableTask(Task):
instances_by_doc_id = defaultdict(list) instances_by_doc_id = defaultdict(list)
for instance in self.instances: for instance in self.instances:
instances_by_doc_id[instance.doc_id].append(instance) instances_by_doc_id[instance.doc_id].append(instance)
all_metrics = collections.defaultdict(list) # all_metrics = collections.defaultdict(list)
all_metrics = []
for filter_key in filter_keys: for filter_key in filter_keys:
doc_iterator = self.doc_iterator( doc_iterator = self.doc_iterator(
rank=rank, rank=rank,
...@@ -1835,12 +1837,25 @@ class ConfigurableTask(Task): ...@@ -1835,12 +1837,25 @@ class ConfigurableTask(Task):
else [req.filtered_resps[filter_key]] else [req.filtered_resps[filter_key]]
) )
] ]
all_metrics[filter_key].append( all_metrics.append(
MetricResult(scores=metrics, doc_id=doc_id, filter_key=filter_key) MetricResult(scores=metrics, doc_id=doc_id, filter_key=filter_key)
) )
return all_metrics return all_metrics
def compute_agg_metrics(
    self, metric_results: list[MetricResult]
) -> dict[tuple, list[list]]:
    """Group per-document metric results and transpose their score values.

    Results are bucketed by their ``(filter_key, metric_name)`` pair. Within
    a bucket, each result contributes the sequence of values found under
    ``metric_name`` in every entry of its ``scores``; those sequences are
    then transposed so that the i-th output list holds the i-th value of
    every result in the bucket.

    NOTE(review): this was previously decorated with ``@staticmethod`` while
    still declaring ``self``, so ``task.compute_agg_metrics(results)`` raised
    ``TypeError``. It is now a regular method; ``self`` is not read.

    Args:
        metric_results: Objects exposing ``filter_key``, ``metric_name`` and
            ``scores`` (an iterable of mappings indexable by ``metric_name``).
            Presumably ``MetricResult`` instances -- confirm that the class
            provides ``metric_name``.

    Returns:
        A dict mapping ``(filter_key, metric_name)`` to a list of lists of
        values. An empty input yields an empty dict.

    Raises:
        KeyError: If a ``scores`` entry has no value for ``metric_name``.
    """
    # One key function shared by sort and groupby: groupby only merges
    # adjacent items, so both must agree on the ordering.
    group_key = attrgetter("filter_key", "metric_name")
    ordered = sorted(metric_results, key=group_key)

    aggregated = {}
    for key, members in groupby(ordered, key=group_key):
        # Every member of the group shares this metric name by construction.
        metric_name = key[1]
        per_result = (
            [entry[metric_name] for entry in result.scores] for result in members
        )
        # zip(*) transposes; it stops at the shortest sequence, so results
        # with fewer score entries truncate the output for their group.
        aggregated[key] = [list(column) for column in zip(*per_result)]
    return aggregated
class MultipleChoiceTask(Task): class MultipleChoiceTask(Task):
OUTPUT_TYPE = "loglikelihood" OUTPUT_TYPE = "loglikelihood"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment