Commit f8920c74 authored by Baber
Browse files

add fallback regex to regex filter

parent 5707ae55
import re
import sys
import unicodedata
from typing import Union
from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter
......@@ -20,6 +21,8 @@ class RegexFilter(Filter):
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select: int = 0,
fallback: str = "[invalid]",
fallback_regex: list[str] = None,
fallback_regex_group_select: list[int] = None,
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
......@@ -27,6 +30,14 @@ class RegexFilter(Filter):
"""
self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern)
self.fallback_regex = (
[re.compile(r) for r in fallback_regex] if fallback_regex else None
)
# One group index per fallback pattern. When the caller supplies none, reuse
# `group_select` for every pattern. The value must be a list (not an int)
# because it is iterated pairwise with the fallback patterns, and
# `fallback_regex` may be None, in which case the list is simply empty.
self.fallback_regex_group_select = (
    fallback_regex_group_select
    if fallback_regex_group_select
    else [group_select] * len(fallback_regex or [])
)
self.group_select = group_select
self.fallback = fallback
......@@ -35,27 +46,45 @@ class RegexFilter(Filter):
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
# independently (and keep them a list.)
def filter_set(inst):
filtered = []
for resp in inst:
match = self.regex.findall(resp)
def process_match(
    match: Union[list[str], list[tuple[str, ...]], tuple[str, ...]],
    group_select: int = self.group_select,
    fallback: str = self.fallback,
) -> str:
    """Reduce the output of ``findall`` to a single stripped string.

    Returns ``fallback`` when ``match`` is empty, or when the selected
    entry is a tuple of capture groups that are all empty.
    """
    if not match:
        return fallback
    selected = match[group_select]
    if not isinstance(selected, tuple):
        return selected.strip()
    # Several capture groups: use the first one that actually captured text.
    first_nonempty = next((group for group in selected if group), None)
    return fallback if first_nonempty is None else first_nonempty.strip()
def try_fallback_regex(resp: str) -> str:
"""Helper function to attempt fallback regex patterns"""
for regex, group_select in zip(
self.fallback_regex, self.fallback_regex_group_select
):
match = regex.findall(resp)
if match:
match = match[self.group_select]
if isinstance(match, tuple):
match = [m for m in match if m]
if match:
match = match[0]
else:
match = self.fallback
match = match.strip()
else:
match = self.fallback
filtered.append(match)
return filtered
return process_match(match, group_select)
return self.fallback
filtered_resps = list(map(lambda x: filter_set(x), resps))
def filter_response(resp: str) -> str:
"""Process a single response string"""
# Try primary regex first
match = self.regex.findall(resp)
if match:
return process_match(match)
return filtered_resps
# If primary regex fails and fallback_regex exists, try those
return try_fallback_regex(resp) if self.fallback_regex else self.fallback
return [
[filter_response(resp) for resp in response_set] for response_set in resps
]
@register_filter("remove_whitespace")
......
......@@ -13,7 +13,7 @@ filter_list:
filter:
- function: "regex"
regex_pattern: 'answer is \(?([ABCDEFGHIJ])\)?'
# regex_pattern: r".*[aA]nswer:\s*([A-J])",
# Patterns tried in order when regex_pattern finds nothing. Single-quoted so
# YAML keeps the backslashes and brackets literally (no Python r"" prefix).
fallback_regex: ['.*[aA]nswer:\s*([A-J])']
- function: "take_first"
generation_kwargs:
until:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment