Commit 90b261e5 authored by Chris
Browse files

Add transformation filters: lowercase, uppercase, map

parent 5b26b3b0
from lm_eval.api.filter import FilterEnsemble from lm_eval.api.filter import FilterEnsemble
from . import selection from . import selection
from . import extraction from . import extraction
from . import transformation
FILTER_REGISTRY = { FILTER_REGISTRY = {
...@@ -9,6 +10,9 @@ FILTER_REGISTRY = { ...@@ -9,6 +10,9 @@ FILTER_REGISTRY = {
"majority_vote": selection.MajorityVoteFilter, "majority_vote": selection.MajorityVoteFilter,
"take_first_k": selection.TakeKFilter, "take_first_k": selection.TakeKFilter,
"remove_whitespace": extraction.WhitespaceFilter, "remove_whitespace": extraction.WhitespaceFilter,
"lowercase": transformation.LowercaseFilter,
"uppercase": transformation.UppercaseFilter,
"map": transformation.MapFilter,
# TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function # TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
# that takes an input and returns a scalar and then should select the max reward, # that takes an input and returns a scalar and then should select the max reward,
# or should implement different filters for different ways of handling a reward model's inference. # or should implement different filters for different ways of handling a reward model's inference.
......
from typing import Any, Optional

from lm_eval.api.filter import Filter
class LowercaseFilter(Filter):
    """Filter that lowercases every response it is given."""

    def __init__(self) -> None:
        """Create the filter; it takes no arguments and keeps no state."""

    def apply(self, resps, docs):
        """Return a list of lists with each response passed through `.lower()`.

        `resps` is iterated as an outer collection of inner response
        collections; the two-level nesting is preserved in the result.
        `docs` is not used by this filter.
        """
        return [
            [text.lower() for text in instance_resps]
            for instance_resps in resps
        ]
class UppercaseFilter(Filter):
    """Filter that uppercases every response it is given."""

    def __init__(self) -> None:
        """Create the filter; it takes no arguments and keeps no state."""

    def apply(self, resps, docs):
        """Return a list of lists with each response passed through `.upper()`.

        `resps` is iterated as an outer collection of inner response
        collections; the two-level nesting is preserved in the result.
        `docs` is not used by this filter.
        """
        return [
            [text.upper() for text in instance_resps]
            for instance_resps in resps
        ]
class MapFilter(Filter):
    """Filter that replaces each response with its value from a lookup dict."""

    def __init__(
        self, mapping_dict: Optional[dict] = None, default_value: Any = None
    ) -> None:
        """
        Initializes the MapFilter with a given mapping dictionary and default value.

        Args:
        - mapping_dict (dict): A dictionary containing the key-value mappings.
                               Default is None, which is treated as an empty dictionary.
        - default_value (Any): The value to be returned when a key is not found in the mapping_dict.
                               Default is None.

        Raises:
        - AssertionError: If mapping_dict is given and is not a dictionary.

        Example:
        mapper = MapFilter({'A': 1, 'B': 2}, default_value=0)
        """
        # Build a fresh dict per instance. A `{}` default in the signature is a
        # single object shared by every MapFilter created without an argument,
        # so mutating one instance's mapping would silently change the others.
        if mapping_dict is None:
            mapping_dict = {}
        # NOTE: kept as an assert so callers keep seeing AssertionError; be
        # aware that this check is skipped when Python runs with -O.
        assert isinstance(mapping_dict, dict), "Provided mapping_dict is not a dictionary"
        # The caller's dict is stored by reference (not copied), as before.
        self.mapping_dict = mapping_dict
        self.default_value = default_value

    def apply(self, resps, docs):
        """Return a list of lists with each response mapped through mapping_dict.

        Each response is used as a key into `self.mapping_dict`; responses that
        are not present map to `self.default_value`. The two-level nesting of
        `resps` is preserved. `docs` is not used by this filter.
        """
        return [
            [self.mapping_dict.get(resp, self.default_value) for resp in inst]
            for inst in resps
        ]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment