Unverified Commit 1554066c authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

delay filter init; remove `*args` (#1369)

* delay filter init; remove `*args`

* bugfix

* optimize

* type hint
parent 7fc43656
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List
from typing import Callable, Iterable, List, Union
from lm_eval.api.instance import Instance
......@@ -14,13 +14,13 @@ class Filter(ABC):
"""
def __init__(self, *args, **kwargs) -> None:
def __init__(self, **kwargs) -> None:
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
@abstractmethod
def apply(self, resps, docs):
def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
"""
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
......@@ -40,7 +40,7 @@ class FilterEnsemble:
"""
name: str
filters: List[Filter]
filters: List[Callable[[], Filter]]
def apply(self, instances: List[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
......@@ -48,7 +48,7 @@ class FilterEnsemble:
for f in self.filters:
# apply filters in sequence
resps = f.apply(resps, docs)
resps = f().apply(resps, docs)
# add the end results after filtering to filtered_requests of their respective source instances.
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
......
......@@ -4,7 +4,12 @@ from typing import Literal, Tuple
@dataclass
class Instance:
request_type: Literal["loglikelihood", "loglikelihood_rolling", "generate_until"]
request_type: Literal[
"loglikelihood",
"loglikelihood_rolling",
"generate_until",
"multiple_choice",
]
doc: dict
arguments: tuple
idx: int
......
......@@ -74,7 +74,12 @@ class TaskConfig(dict):
num_fewshot: int = None
# scoring options
metric_list: list = None
output_type: str = "generate_until"
output_type: Literal[
"loglikelihood",
"loglikelihood_rolling",
"generate_until",
"multiple_choice",
] = "generate_until"
generation_kwargs: dict = None
repeats: int = 1
filter_list: Union[str, list] = None
......
from typing import List
from typing import List, Union
from functools import partial
from lm_eval.api.filter import FilterEnsemble
from . import selection
......@@ -22,7 +23,7 @@ FILTER_REGISTRY = {
}
def get_filter(filter_name):
def get_filter(filter_name: str) -> Union[type, str]:
if filter_name in FILTER_REGISTRY:
return FILTER_REGISTRY[filter_name]
else:
......@@ -38,10 +39,9 @@ def build_filter_ensemble(
filters = []
for function, kwargs in components:
if kwargs is None:
f = get_filter(function)()
else:
# create a filter given its name in the registry
f = get_filter(function)(**kwargs) # TODO: pass kwargs to filters properly
kwargs = {}
# create a filter given its name in the registry
f = partial(get_filter(function), **kwargs)
# add the filter as a pipeline step
filters.append(f)
......
......@@ -17,12 +17,14 @@ class TakeFirstFilter(Filter):
class TakeKFilter(Filter):
def __init__(self, *args, **kwargs) -> None:
def __init__(self, **kwargs) -> None:
self.k = kwargs.pop("k")
super().__init__(*args, **kwargs)
super().__init__(**kwargs)
def apply(self, resps, docs):
# need resp to be subscriptable to check below
resps = list(resps)
# check we have at least k responses per doc, else we can't take the first k
assert (
len(resps[0]) >= self.k
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment