Commit 023bfe0d authored by Baber's avatar Baber
Browse files

cleanup

parent 49bfaf68
...@@ -10,12 +10,13 @@ class DecontaminationFilter(Filter): ...@@ -10,12 +10,13 @@ class DecontaminationFilter(Filter):
name = "track_decontamination" name = "track_decontamination"
def __init__(self, path) -> None: def __init__(self, path, **kwargs) -> None:
""" """
TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path"). TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
should further cache result on a given (task_name, doc_id) should further cache result on a given (task_name, doc_id)
""" """
super().__init__(**kwargs)
self._decontam_results = None self._decontam_results = None
def apply(self, resps, docs) -> None: def apply(self, resps, docs) -> None:
......
...@@ -20,11 +20,13 @@ class RegexFilter(Filter): ...@@ -20,11 +20,13 @@ class RegexFilter(Filter):
regex_pattern: str = r"#### (\-?[0-9\.\,]+)", regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select: int = 0, group_select: int = 0,
fallback: str = "[invalid]", fallback: str = "[invalid]",
**kwargs,
) -> None: ) -> None:
""" """
pass a string `regex` to run `re.compile(r"regex")` on. pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located. `fallback` defines the output returned if no matches for the regex are located.
""" """
super().__init__(**kwargs)
self.regex_pattern = regex_pattern self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern) self.regex = re.compile(regex_pattern)
self.group_select = group_select self.group_select = group_select
...@@ -66,11 +68,13 @@ class POSFilter(Filter): ...@@ -66,11 +68,13 @@ class POSFilter(Filter):
regex_pattern: str = r"\['(.*?)'\]", regex_pattern: str = r"\['(.*?)'\]",
group_select=0, group_select=0,
fallback=None, fallback=None,
**kwargs,
) -> None: ) -> None:
""" """
pass a string `regex` to run `re.compile(r"regex")` on. pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located. `fallback` defines the output returned if no matches for the regex are located.
""" """
super().__init__(**kwargs)
if fallback is None: if fallback is None:
fallback = ["invalid"] fallback = ["invalid"]
self.regex_pattern = regex_pattern self.regex_pattern = regex_pattern
......
...@@ -27,7 +27,6 @@ class TakeFirstFilter(Filter): ...@@ -27,7 +27,6 @@ class TakeFirstFilter(Filter):
class TakeKFilter(Filter): class TakeKFilter(Filter):
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
self.k = kwargs.pop("k") self.k = kwargs.pop("k")
super().__init__(**kwargs) super().__init__(**kwargs)
def apply(self, resps, docs): def apply(self, resps, docs):
......
...@@ -6,9 +6,6 @@ from lm_eval.api.registry import register_filter ...@@ -6,9 +6,6 @@ from lm_eval.api.registry import register_filter
@register_filter("lowercase") @register_filter("lowercase")
class LowercaseFilter(Filter): class LowercaseFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs): def apply(self, resps, docs):
def filter_set(inst): def filter_set(inst):
return [resp.lower() for resp in inst] return [resp.lower() for resp in inst]
...@@ -18,9 +15,6 @@ class LowercaseFilter(Filter): ...@@ -18,9 +15,6 @@ class LowercaseFilter(Filter):
@register_filter("uppercase") @register_filter("uppercase")
class UppercaseFilter(Filter): class UppercaseFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs): def apply(self, resps, docs):
def filter_set(inst): def filter_set(inst):
return [resp.upper() for resp in inst] return [resp.upper() for resp in inst]
...@@ -31,6 +25,7 @@ class UppercaseFilter(Filter): ...@@ -31,6 +25,7 @@ class UppercaseFilter(Filter):
@register_filter("map") @register_filter("map")
class MapFilter(Filter): class MapFilter(Filter):
def __init__(self, mapping_dict: dict = None, default_value=None) -> None: def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
super().__init__()
""" """
Initializes the MapFilter with a given mapping dictionary and default value. Initializes the MapFilter with a given mapping dictionary and default value.
...@@ -60,9 +55,6 @@ class MapFilter(Filter): ...@@ -60,9 +55,6 @@ class MapFilter(Filter):
@register_filter("format_span") @register_filter("format_span")
class SPANFilter(Filter): class SPANFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs): def apply(self, resps, docs):
def format_ner_text(text): def format_ner_text(text):
label_dict = { label_dict = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment