Unverified Commit f64e72f5 authored by Nikita Lozhnikov's avatar Nikita Lozhnikov Committed by GitHub
Browse files

Add filter registry decorator (#1750)

* Add register_filter decorator

* Add register_filter docs
parent 9b49556a
...@@ -155,6 +155,21 @@ Our final filter pipeline, "maj@8", does majority voting across the first 8 of t ...@@ -155,6 +155,21 @@ Our final filter pipeline, "maj@8", does majority voting across the first 8 of t
Thus, given the 64 responses from our LM on each document, we can report metrics on these responses in these 3 different ways, as defined by our filter pipelines. Thus, given the 64 responses from our LM on each document, we can report metrics on these responses in these 3 different ways, as defined by our filter pipelines.
### Adding a custom filter
Just like adding a custom model with `register_model` decorator one is able to do the same with filters, for example
```python
from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter
@register_filter("new_filter")
class NewFilter(Filter)
...
```
## Embedded Python Code ## Embedded Python Code
Use can use python functions for certain arguments by using the `!function` operator after the argument name followed by `<filename>.<pythonfunctionname>`. This feature can be used for the following arguments: Use can use python functions for certain arguments by using the `!function` operator after the argument name followed by `<filename>.<pythonfunctionname>`. This feature can be used for the following arguments:
......
...@@ -78,6 +78,7 @@ METRIC_REGISTRY = {} ...@@ -78,6 +78,7 @@ METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {} METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {} AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {}
HIGHER_IS_BETTER_REGISTRY = {} HIGHER_IS_BETTER_REGISTRY = {}
FILTER_REGISTRY = {}
DEFAULT_METRIC_REGISTRY = { DEFAULT_METRIC_REGISTRY = {
"loglikelihood": [ "loglikelihood": [
...@@ -170,3 +171,22 @@ def is_higher_better(metric_name) -> bool: ...@@ -170,3 +171,22 @@ def is_higher_better(metric_name) -> bool:
eval_logger.warning( eval_logger.warning(
f"higher_is_better not specified for metric '{metric_name}'!" f"higher_is_better not specified for metric '{metric_name}'!"
) )
def register_filter(name):
def decorate(cls):
if name in FILTER_REGISTRY:
eval_logger.info(
f"Registering filter `{name}` that is already in Registry {FILTER_REGISTRY}"
)
FILTER_REGISTRY[name] = cls
return cls
return decorate
def get_filter(filter_name: str) -> type:
try:
return FILTER_REGISTRY[filter_name]
except KeyError:
eval_logger.warning(f"filter `{filter_name}` is not registered!")
from functools import partial from functools import partial
from typing import List, Union from typing import List
from lm_eval.api.filter import FilterEnsemble from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.registry import get_filter
from . import extraction, selection, transformation from . import extraction, selection, transformation
FILTER_REGISTRY = {
"take_first": selection.TakeFirstFilter,
"regex": extraction.RegexFilter,
"majority_vote": selection.MajorityVoteFilter,
"take_first_k": selection.TakeKFilter,
"remove_whitespace": extraction.WhitespaceFilter,
"lowercase": transformation.LowercaseFilter,
"uppercase": transformation.UppercaseFilter,
"map": transformation.MapFilter,
"multi_choice_regex": extraction.MultiChoiceRegexFilter,
# TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
# that takes an input and returns a scalar and then should select the max reward,
# or should implement different filters for different ways of handling a reward model's inference.
# "arg_max": selection.ArgMaxFilter,
}
def get_filter(filter_name: str) -> Union[type, str]:
if filter_name in FILTER_REGISTRY:
return FILTER_REGISTRY[filter_name]
else:
return filter_name
def build_filter_ensemble( def build_filter_ensemble(
filter_name: str, components: List[List[str]] filter_name: str, components: List[List[str]]
) -> FilterEnsemble: ) -> FilterEnsemble:
......
from lm_eval.api.filter import Filter from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter
@register_filter("decontaminate")
class DecontaminationFilter(Filter): class DecontaminationFilter(Filter):
""" """
......
...@@ -3,8 +3,10 @@ import sys ...@@ -3,8 +3,10 @@ import sys
import unicodedata import unicodedata
from lm_eval.api.filter import Filter from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter
@register_filter("regex")
class RegexFilter(Filter): class RegexFilter(Filter):
""" """ """ """
...@@ -49,6 +51,7 @@ class RegexFilter(Filter): ...@@ -49,6 +51,7 @@ class RegexFilter(Filter):
return filtered_resps return filtered_resps
@register_filter("remove_whitespace")
class WhitespaceFilter(Filter): class WhitespaceFilter(Filter):
""" """ """ """
...@@ -71,6 +74,7 @@ class WhitespaceFilter(Filter): ...@@ -71,6 +74,7 @@ class WhitespaceFilter(Filter):
return filtered_resps return filtered_resps
@register_filter("multi_choice_regex")
class MultiChoiceRegexFilter(RegexFilter): class MultiChoiceRegexFilter(RegexFilter):
""" """
A filter used to extract a model's answer on multiple choice questions with A filter used to extract a model's answer on multiple choice questions with
......
from collections import Counter from collections import Counter
from lm_eval.api.filter import Filter from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter
# TODO: implement "arg_max" filter. either it should take in an arbitrary "scoring"/reward function
# that takes an input and returns a scalar and then should select the max reward,
# or should implement different filters for different ways of handling a reward model's inference.
@register_filter("take_first")
class TakeFirstFilter(Filter): class TakeFirstFilter(Filter):
def __init__(self) -> None: def __init__(self) -> None:
""" """
...@@ -16,6 +23,7 @@ class TakeFirstFilter(Filter): ...@@ -16,6 +23,7 @@ class TakeFirstFilter(Filter):
return map(lambda r: r[0], resps) return map(lambda r: r[0], resps)
@register_filter("take_first_k")
class TakeKFilter(Filter): class TakeKFilter(Filter):
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
self.k = kwargs.pop("k") self.k = kwargs.pop("k")
...@@ -32,6 +40,7 @@ class TakeKFilter(Filter): ...@@ -32,6 +40,7 @@ class TakeKFilter(Filter):
return map(lambda r: r[: self.k], resps) return map(lambda r: r[: self.k], resps)
@register_filter("majority_vote")
class MajorityVoteFilter(Filter): class MajorityVoteFilter(Filter):
def __init__(self) -> None: def __init__(self) -> None:
""" """
......
from lm_eval.api.filter import Filter from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter
@register_filter("lowercase")
class LowercaseFilter(Filter): class LowercaseFilter(Filter):
def __init__(self) -> None: def __init__(self) -> None:
pass pass
...@@ -12,6 +14,7 @@ class LowercaseFilter(Filter): ...@@ -12,6 +14,7 @@ class LowercaseFilter(Filter):
return [filter_set(resp) for resp in resps] return [filter_set(resp) for resp in resps]
@register_filter("uppercase")
class UppercaseFilter(Filter): class UppercaseFilter(Filter):
def __init__(self) -> None: def __init__(self) -> None:
pass pass
...@@ -23,6 +26,7 @@ class UppercaseFilter(Filter): ...@@ -23,6 +26,7 @@ class UppercaseFilter(Filter):
return [filter_set(resp) for resp in resps] return [filter_set(resp) for resp in resps]
@register_filter("map")
class MapFilter(Filter): class MapFilter(Filter):
def __init__(self, mapping_dict: dict = None, default_value=None) -> None: def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment