Commit 69e95b87 authored by Baber
Browse files

fix filters and metrics

parent 72f5a5df
......@@ -3,6 +3,7 @@ from dataclasses import dataclass
from typing import Callable, Iterable, List, Union
from lm_eval.api.instance import Instance
from lm_eval.api.schemas import GenerateOutput
class Filter(ABC):
......@@ -45,7 +46,14 @@ class FilterEnsemble:
def apply(self, instances: List[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
# TODO: add backward
resps, docs = list([r.text] for y in resps for r in y), list(docs)
# unwrap responses from GenerateOutput as the filters expect strings
resps = tuple(
[
item.text if isinstance(item, GenerateOutput) else str(item)
for item in sublist
]
for sublist in resps
)
for f in self.filters:
# apply filters in sequence
......
......@@ -1769,7 +1769,7 @@ class ConfigurableTask(Task):
def calculate_metrics(
self, instances_by_doc_id, filter_key, samples, rank, limit, world_size
):
) -> list[list[dict]]:
"""Calculate metrics for all datapoints in the task.
Args:
......@@ -1797,13 +1797,14 @@ class ConfigurableTask(Task):
# doc_id_true = indices[doc_id] if indices else doc_id
requests = instances_by_doc_id[doc_id]
metrics = [
metrics: list[list[dict]] = [
self.process_results(doc, response)
for req in requests
for response in req.filtered_resps[filter_key]
]
all_metrics.extend(metrics)
# TODO: This turns metrics into a list of lists of dicts rather than a flat list.
all_metrics.append(metrics)
return all_metrics
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment