Commit 023bfe0d authored by Baber's avatar Baber
Browse files

cleanup

parent 49bfaf68
......@@ -10,12 +10,13 @@ class DecontaminationFilter(Filter):
name = "track_decontamination"
def __init__(self, path) -> None:
def __init__(self, path, **kwargs) -> None:
"""
TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
should further cache result on a given (task_name, doc_id)
"""
super().__init__(**kwargs)
self._decontam_results = None
def apply(self, resps, docs) -> None:
......
......@@ -20,11 +20,13 @@ class RegexFilter(Filter):
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select: int = 0,
fallback: str = "[invalid]",
**kwargs,
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
super().__init__(**kwargs)
self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern)
self.group_select = group_select
......@@ -66,11 +68,13 @@ class POSFilter(Filter):
regex_pattern: str = r"\['(.*?)'\]",
group_select=0,
fallback=None,
**kwargs,
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
super().__init__(**kwargs)
if fallback is None:
fallback = ["invalid"]
self.regex_pattern = regex_pattern
......
......@@ -27,7 +27,6 @@ class TakeFirstFilter(Filter):
class TakeKFilter(Filter):
def __init__(self, **kwargs) -> None:
self.k = kwargs.pop("k")
super().__init__(**kwargs)
def apply(self, resps, docs):
......
......@@ -6,9 +6,6 @@ from lm_eval.api.registry import register_filter
@register_filter("lowercase")
class LowercaseFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs):
def filter_set(inst):
return [resp.lower() for resp in inst]
......@@ -18,9 +15,6 @@ class LowercaseFilter(Filter):
@register_filter("uppercase")
class UppercaseFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs):
def filter_set(inst):
return [resp.upper() for resp in inst]
......@@ -31,6 +25,7 @@ class UppercaseFilter(Filter):
@register_filter("map")
class MapFilter(Filter):
def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
super().__init__()
"""
Initializes the MapFilter with a given mapping dictionary and default value.
......@@ -60,9 +55,6 @@ class MapFilter(Filter):
@register_filter("format_span")
class SPANFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs):
def format_ner_text(text):
label_dict = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment