Commit 3fd12675 authored by Baber
Browse files

type hints;

parent 46654b3d
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Iterable, List, Union
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
...@@ -20,7 +20,9 @@ class Filter(ABC): ...@@ -20,7 +20,9 @@ class Filter(ABC):
""" """
@abstractmethod @abstractmethod
def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable: def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
""" """
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects. Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g. Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
...@@ -40,9 +42,9 @@ class FilterEnsemble: ...@@ -40,9 +42,9 @@ class FilterEnsemble:
""" """
name: str name: str
filters: List[Callable[[], Filter]] filters: list[Callable[[], Filter]]
def apply(self, instances: List[Instance]) -> None: def apply(self, instances: list[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances)) resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
resps, docs = list(resps), list(docs) resps, docs = list(resps), list(docs)
......
import re import re
import sys import sys
import unicodedata import unicodedata
from collections.abc import Iterable
from lm_eval.api.filter import Filter from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter from lm_eval.api.registry import register_filter
...@@ -30,7 +31,9 @@ class RegexFilter(Filter): ...@@ -30,7 +31,9 @@ class RegexFilter(Filter):
self.group_select = group_select self.group_select = group_select
self.fallback = fallback self.fallback = fallback
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]: def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
# here, we assume we have a list, in which each element is # here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair. # a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets) # so we process each of these (same input/target response sets)
...@@ -57,57 +60,13 @@ class RegexFilter(Filter): ...@@ -57,57 +60,13 @@ class RegexFilter(Filter):
return filtered_resps return filtered_resps
@register_filter("regex_pos")
class POSFilter(Filter):
    """Extract part-of-speech tags from responses made of ``('token', 'tag')`` pairs."""

    def __init__(
        self,
        regex_pattern: str = r"\['(.*?)'\]",
        group_select=0,
        fallback=None,
    ) -> None:
        """
        Store the filter configuration.

        `regex_pattern` is compiled with `re.compile` and kept on the instance as `regex`.
        `group_select` is stored unchanged.
        `fallback` is the value returned for a response that yields no tags;
        when omitted it defaults to `["invalid"]`.
        """
        self.regex_pattern = regex_pattern
        self.regex = re.compile(regex_pattern)
        self.group_select = group_select
        self.fallback = ["invalid"] if fallback is None else fallback

    def apply(self, resps, docs):
        """
        Map every response in every response set to its list of POS tags.

        A string response is scanned for `('token', 'tag')` tuples; any other
        response is iterated directly as `(token, tag)` pairs. A response that
        produces no tags is replaced by `self.fallback` (the same object each time).
        `docs` is not used. The result is a lazy `map` over `resps`, yielding one
        list per response set in input order.
        """
        # NOTE(review): tag extraction relies on the literal pattern below;
        # `self.regex` and `self.group_select` are not consulted here.

        def tags_for(response):
            pairs = response
            if isinstance(response, str):
                pairs = re.findall(r"\('([^']*)', '([^']*)'\)", response)
            tags = [tag for _, tag in pairs]
            return tags or self.fallback

        return map(lambda group: [tags_for(item) for item in group], resps)
@register_filter("remove_whitespace") @register_filter("remove_whitespace")
class WhitespaceFilter(Filter): class WhitespaceFilter(Filter):
"""Filters out leading whitespace from responses.""" """Filters out leading whitespace from responses."""
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]: def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
def filter_set(inst): def filter_set(inst):
filtered_resp = [] filtered_resp = []
for resp in inst: for resp in inst:
...@@ -152,7 +111,9 @@ class MultiChoiceRegexFilter(RegexFilter): ...@@ -152,7 +111,9 @@ class MultiChoiceRegexFilter(RegexFilter):
self.ignore_punctuation = ignore_punctuation self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore self.regexes_to_ignore = regexes_to_ignore
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]: def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
# here, we assume we have a list, in which each element is # here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair. # a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets) # so we process each of these (same input/target response sets)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment