Commit f8920c74 authored by Baber's avatar Baber
Browse files

add fallback regex to regex filter

parent 5707ae55
import re import re
import sys import sys
import unicodedata import unicodedata
from typing import Union
from lm_eval.api.filter import Filter from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter from lm_eval.api.registry import register_filter
...@@ -20,6 +21,8 @@ class RegexFilter(Filter): ...@@ -20,6 +21,8 @@ class RegexFilter(Filter):
regex_pattern: str = r"#### (\-?[0-9\.\,]+)", regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select: int = 0, group_select: int = 0,
fallback: str = "[invalid]", fallback: str = "[invalid]",
fallback_regex: list[str] = None,
fallback_regex_group_select: list[int] = None,
) -> None: ) -> None:
""" """
pass a string `regex` to run `re.compile(r"regex")` on. pass a string `regex` to run `re.compile(r"regex")` on.
...@@ -27,6 +30,14 @@ class RegexFilter(Filter): ...@@ -27,6 +30,14 @@ class RegexFilter(Filter):
""" """
self.regex_pattern = regex_pattern self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern) self.regex = re.compile(regex_pattern)
self.fallback_regex = (
[re.compile(r) for r in fallback_regex] if fallback_regex else None
)
self.fallback_regex_group_select = (
fallback_regex_group_select
if fallback_regex_group_select
else group_select * len(fallback_regex)
)
self.group_select = group_select self.group_select = group_select
self.fallback = fallback self.fallback = fallback
...@@ -35,27 +46,45 @@ class RegexFilter(Filter): ...@@ -35,27 +46,45 @@ class RegexFilter(Filter):
# a list of model responses for some particular input/target pair. # a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets) # so we process each of these (same input/target response sets)
# independently (and keep them a list.) # independently (and keep them a list.)
def filter_set(inst): def process_match(
filtered = [] match: Union[list[str], list[tuple[str, ...]], tuple[str, ...]],
for resp in inst: group_select: int = self.group_select,
match = self.regex.findall(resp) fallback: str = self.fallback,
) -> str:
"""Helper function to process regex match results"""
if not match:
return fallback
match = match[group_select]
if isinstance(match, tuple):
# Filter out empty strings and get first non-empty match if it exists
valid_matches = [m for m in match if m]
return valid_matches[0].strip() if valid_matches else fallback
return match.strip()
def try_fallback_regex(resp: str) -> str:
"""Helper function to attempt fallback regex patterns"""
for regex, group_select in zip(
self.fallback_regex, self.fallback_regex_group_select
):
match = regex.findall(resp)
if match: if match:
match = match[self.group_select] return process_match(match, group_select)
if isinstance(match, tuple): return self.fallback
match = [m for m in match if m]
if match:
match = match[0]
else:
match = self.fallback
match = match.strip()
else:
match = self.fallback
filtered.append(match)
return filtered
filtered_resps = list(map(lambda x: filter_set(x), resps)) def filter_response(resp: str) -> str:
"""Process a single response string"""
# Try primary regex first
match = self.regex.findall(resp)
if match:
return process_match(match)
return filtered_resps # If primary regex fails and fallback_regex exists, try those
return try_fallback_regex(resp) if self.fallback_regex else self.fallback
return [
[filter_response(resp) for resp in response_set] for response_set in resps
]
@register_filter("remove_whitespace") @register_filter("remove_whitespace")
......
...@@ -13,7 +13,7 @@ filter_list: ...@@ -13,7 +13,7 @@ filter_list:
filter: filter:
- function: "regex" - function: "regex"
regex_pattern: 'answer is \(?([ABCDEFGHIJ])\)?' regex_pattern: 'answer is \(?([ABCDEFGHIJ])\)?'
# regex_pattern: r".*[aA]nswer:\s*([A-J])", fallback_regex: [r".*[aA]nswer:\s*([A-J])"]
- function: "take_first" - function: "take_first"
generation_kwargs: generation_kwargs:
until: until:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment