Commit 65425258 authored by Baber's avatar Baber
Browse files

fix filter type hint

parent 708b160d
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Iterable, List, Union
from typing import Iterable, List
from lm_eval.api.instance import Instance
......@@ -20,7 +20,9 @@ class Filter(ABC):
"""
@abstractmethod
def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
def apply(
self, resps: Iterable[list[str]], docs: List[dict]
) -> Iterable[list[str]]:
"""
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
......@@ -40,7 +42,7 @@ class FilterEnsemble:
"""
name: str
filters: List[Callable[[], Filter]]
filters: List[type[Filter]]
def apply(self, instances: List[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
......@@ -48,7 +50,10 @@ class FilterEnsemble:
for f in self.filters:
# apply filters in sequence
resps = f().apply(resps, docs)
try:
resps = f().apply(resps, docs)
except (TypeError, AttributeError):
resps = f().apply(list(resps), docs)
# add the end results after filtering to filtered_requests of their respective source instances.
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
......
from collections import Counter
from typing import Sequence
from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter
......@@ -16,7 +17,7 @@ class TakeFirstFilter(Filter):
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
def apply(self, resps, docs):
def apply(self, resps: Sequence, docs):
"""
Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
"""
......@@ -30,7 +31,7 @@ class TakeKFilter(Filter):
super().__init__(**kwargs)
def apply(self, resps, docs):
def apply(self, resps: Sequence, docs):
# need resp to be subscriptable to check below
resps = list(resps)
# check we have at least k responses per doc, else we can't take the first k
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment