Unverified commit 1554066c, authored by Baber Abbasi, committed by GitHub
Browse files

delay filter init; remove `*args` (#1369)

* delay filter init; remove `*args`

* bugfix

* optimize

* type hint
parent 7fc43656
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import List from typing import Callable, Iterable, List, Union
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
...@@ -14,13 +14,13 @@ class Filter(ABC): ...@@ -14,13 +14,13 @@ class Filter(ABC):
""" """
def __init__(self, *args, **kwargs) -> None: def __init__(self, **kwargs) -> None:
""" """
Can define custom behavior here, if an individual instantiation of a Filter class should have state. Can define custom behavior here, if an individual instantiation of a Filter class should have state.
""" """
@abstractmethod @abstractmethod
def apply(self, resps, docs): def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
""" """
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects. Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g. Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
...@@ -40,7 +40,7 @@ class FilterEnsemble: ...@@ -40,7 +40,7 @@ class FilterEnsemble:
""" """
name: str name: str
filters: List[Filter] filters: List[Callable[[], Filter]]
def apply(self, instances: List[Instance]) -> None: def apply(self, instances: List[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances)) resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
...@@ -48,7 +48,7 @@ class FilterEnsemble: ...@@ -48,7 +48,7 @@ class FilterEnsemble:
for f in self.filters: for f in self.filters:
# apply filters in sequence # apply filters in sequence
resps = f.apply(resps, docs) resps = f().apply(resps, docs)
# add the end results after filtering to filtered_requests of their respective source instances. # add the end results after filtering to filtered_requests of their respective source instances.
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name. # has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
......
...@@ -4,7 +4,12 @@ from typing import Literal, Tuple ...@@ -4,7 +4,12 @@ from typing import Literal, Tuple
@dataclass @dataclass
class Instance: class Instance:
request_type: Literal["loglikelihood", "loglikelihood_rolling", "generate_until"] request_type: Literal[
"loglikelihood",
"loglikelihood_rolling",
"generate_until",
"multiple_choice",
]
doc: dict doc: dict
arguments: tuple arguments: tuple
idx: int idx: int
......
...@@ -74,7 +74,12 @@ class TaskConfig(dict): ...@@ -74,7 +74,12 @@ class TaskConfig(dict):
num_fewshot: int = None num_fewshot: int = None
# scoring options # scoring options
metric_list: list = None metric_list: list = None
output_type: str = "generate_until" output_type: Literal[
"loglikelihood",
"loglikelihood_rolling",
"generate_until",
"multiple_choice",
] = "generate_until"
generation_kwargs: dict = None generation_kwargs: dict = None
repeats: int = 1 repeats: int = 1
filter_list: Union[str, list] = None filter_list: Union[str, list] = None
......
from typing import List from typing import List, Union
from functools import partial
from lm_eval.api.filter import FilterEnsemble from lm_eval.api.filter import FilterEnsemble
from . import selection from . import selection
...@@ -22,7 +23,7 @@ FILTER_REGISTRY = { ...@@ -22,7 +23,7 @@ FILTER_REGISTRY = {
} }
def get_filter(filter_name): def get_filter(filter_name: str) -> Union[type, str]:
if filter_name in FILTER_REGISTRY: if filter_name in FILTER_REGISTRY:
return FILTER_REGISTRY[filter_name] return FILTER_REGISTRY[filter_name]
else: else:
...@@ -38,10 +39,9 @@ def build_filter_ensemble( ...@@ -38,10 +39,9 @@ def build_filter_ensemble(
filters = [] filters = []
for function, kwargs in components: for function, kwargs in components:
if kwargs is None: if kwargs is None:
f = get_filter(function)() kwargs = {}
else:
# create a filter given its name in the registry # create a filter given its name in the registry
f = get_filter(function)(**kwargs) # TODO: pass kwargs to filters properly f = partial(get_filter(function), **kwargs)
# add the filter as a pipeline step # add the filter as a pipeline step
filters.append(f) filters.append(f)
......
...@@ -17,12 +17,14 @@ class TakeFirstFilter(Filter): ...@@ -17,12 +17,14 @@ class TakeFirstFilter(Filter):
class TakeKFilter(Filter): class TakeKFilter(Filter):
def __init__(self, *args, **kwargs) -> None: def __init__(self, **kwargs) -> None:
self.k = kwargs.pop("k") self.k = kwargs.pop("k")
super().__init__(*args, **kwargs) super().__init__(**kwargs)
def apply(self, resps, docs): def apply(self, resps, docs):
# need resp to be subscriptable to check below
resps = list(resps)
# check we have at least k responses per doc, else we can't take the first k # check we have at least k responses per doc, else we can't take the first k
assert ( assert (
len(resps[0]) >= self.k len(resps[0]) >= self.k
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment