Commit 69e95b87 authored by Baber's avatar Baber
Browse files

fix filters and metrics

parent 72f5a5df
...@@ -3,6 +3,7 @@ from dataclasses import dataclass ...@@ -3,6 +3,7 @@ from dataclasses import dataclass
from typing import Callable, Iterable, List, Union from typing import Callable, Iterable, List, Union
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
from lm_eval.api.schemas import GenerateOutput
class Filter(ABC): class Filter(ABC):
...@@ -45,7 +46,14 @@ class FilterEnsemble: ...@@ -45,7 +46,14 @@ class FilterEnsemble:
def apply(self, instances: List[Instance]) -> None: def apply(self, instances: List[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances)) resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
# TODO: add backward # TODO: add backward
resps, docs = list([r.text] for y in resps for r in y), list(docs) # unwrap responses from GenerateOutput as the filters expect strings
resps = tuple(
[
item.text if isinstance(item, GenerateOutput) else str(item)
for item in sublist
]
for sublist in resps
)
for f in self.filters: for f in self.filters:
# apply filters in sequence # apply filters in sequence
......
...@@ -1769,7 +1769,7 @@ class ConfigurableTask(Task): ...@@ -1769,7 +1769,7 @@ class ConfigurableTask(Task):
def calculate_metrics( def calculate_metrics(
self, instances_by_doc_id, filter_key, samples, rank, limit, world_size self, instances_by_doc_id, filter_key, samples, rank, limit, world_size
): ) -> list[list[dict]]:
"""Calculate metrics for all datapoints in the task. """Calculate metrics for all datapoints in the task.
Args: Args:
...@@ -1797,13 +1797,14 @@ class ConfigurableTask(Task): ...@@ -1797,13 +1797,14 @@ class ConfigurableTask(Task):
# doc_id_true = indices[doc_id] if indices else doc_id # doc_id_true = indices[doc_id] if indices else doc_id
requests = instances_by_doc_id[doc_id] requests = instances_by_doc_id[doc_id]
metrics = [ metrics: list[list[dict]] = [
self.process_results(doc, response) self.process_results(doc, response)
for req in requests for req in requests
for response in req.filtered_resps[filter_key] for response in req.filtered_resps[filter_key]
] ]
all_metrics.extend(metrics) # TODO: This turns metrics into a list of lists of dicts rather than flat list.
all_metrics.append(metrics)
return all_metrics return all_metrics
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment