Commit 81f84653 authored by Baber's avatar Baber
Browse files

refactor multichoiceregex

parent f8920c74
import re
import sys
import unicodedata
import string
from typing import Union
from lm_eval.api.filter import Filter
......@@ -106,112 +105,228 @@ class WhitespaceFilter(Filter):
@register_filter("multi_choice_regex")
class MultiChoiceRegexFilter(RegexFilter):
"""
A filter used to extract a model's answer on multiple choice questions with
letter answers. assumes each document has a "choices" field
containing the list of answer choices and that the answer label symbols
are of the form (A), (B), (C), ... or A, B, C.
"""A filter for extracting multiple choice answers from text responses.
This filter processes responses in the following order:
1. Full text matches of answer choices (e.g., "The earth is round" -> "A")
2. Letter-based answers in various formats (e.g., "(A)", "A:", "Answer: A")
Args:
regex_pattern (str, optional): Custom regex pattern for matching. If None, uses default pattern.
group_select (int, default=0): Which regex group to select from matches.
fallback (str, default="[invalid]"): Value to return when no match is found.
ignore_case (bool, default=True): Whether to ignore case when matching.
ignore_punctuation (bool, default=False): Whether to ignore punctuation when matching.
regexes_to_ignore (list, optional): List of regex patterns to remove from text before matching.
max_choices (int, default=4): Maximum number of choices to consider (A-D).
choices_field (str, default="choices"): Field name or dot path to get choices from document.
format_style (str, default="plain"): Output format style ("plain" for "A", "parens" for "(A)").
Examples:
>>> filter = MultiChoiceRegexFilter(format_style="parens", choices_field="choices")
>>> doc = {"choices": ["The earth is round", "The earth is flat"]}
>>> responses = ["The earth is round", "Answer: B", "(A)"]
>>> filter.apply([responses], [doc])
[["(A)", "(B)", "(A)"]]
# With nested choices
>>> doc = {"metadata": {"question": {"choices": ["True", "False"]}}}
>>> filter = MultiChoiceRegexFilter(choices_field="metadata.question.choices")
# With custom format
>>> filter = MultiChoiceRegexFilter(format_style="plain") # Returns "A" instead of "(A)"
"""
def __init__(
# NOTE(review): this span interleaves pre-refactor and post-refactor lines
# (every parameter up to regexes_to_ignore appears twice, and fragments of
# the previous apply() and its nested helpers follow the constructor body).
# Read the annotated parameters as the current signature.
self,
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select=0,
regex_pattern: str = None,
group_select: int = 0,
fallback: str = "[invalid]",
ignore_case=False,
ignore_punctuation=False,
regexes_to_ignore=None,
ignore_case: bool = True,
ignore_punctuation: bool = False,
regexes_to_ignore: list = None,
max_choices: int = 4, # number of lettered choices; the default 4 yields A-D
choices_field: str = "choices",
format_style: str = "plain",
) -> None:
"""
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
- step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
- step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
group_select: Selects the (group_select)th match from the findall result.
ignore_case: Ignores the case during step 1 matching
ignore_punctuation: Remove the punctuation during step 1 matching
regexes_to_ignore: Remove these regexes during step 1 matching
"""
super().__init__(regex_pattern, group_select, fallback)
# Store the matching options on the instance for later use.
self.ignore_case = ignore_case
self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore
self.max_choices = max_choices
self.choices_field = choices_field
self.format_style = format_style
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
# independently (and keep them a list.)
# If no custom pattern, create comprehensive letter pattern
if regex_pattern is None:
# One letter per allowed choice, starting at "A" (max_choices=4 -> "ABCD").
letters = "".join([chr(ord("A") + i) for i in range(max_choices)])
# Matches (A), A:, Answer: A etc.
regex_pattern = rf"(?:\(([{letters}])\))|(?:(?:answer|choice|option)?:?\s*([{letters}])(?:\s|$))"
# NOTE(review): convert_dict={} is a mutable default argument; it is only
# read in the lines below, never mutated.
def find_match(regex, resp, convert_dict={}):
match = regex.findall(resp)
if match:
match = match[self.group_select]
if isinstance(match, tuple):
match = [m for m in match if m][0]
match = match.strip()
if match and match in convert_dict:
match = convert_dict[match]
return match
# Translation table whose keys are all code points in Unicode "P*"
# (punctuation) categories; values are None, so translate() deletes them.
punct_tbl = dict.fromkeys(
i
for i in range(sys.maxunicode)
if unicodedata.category(chr(i)).startswith("P")
)
# Hands the (possibly defaulted) pattern, group index and fallback to the
# parent class constructor.
super().__init__(regex_pattern, group_select, fallback)
def filter_ignores(st):
if self.regexes_to_ignore is not None:
for s in self.regexes_to_ignore:
st = re.sub(s, "", st)
def _format_letter(self, letter: str) -> str:
"""Format a letter based on format_style setting"""
if self.format_style == "parens":
return f"({letter})"
elif self.format_style == "plain":
return letter
# Add more format styles here as needed:
# elif self.format_style == "brackets":
# return f"[{letter}]"
# elif self.format_style == "numbered":
# return f"{ord(letter) - ord('A') + 1}"
else:
return f"({letter})"
def _filter_text(self, text: str) -> str:
"""Apply text filtering rules (case, punctuation, regex ignores)"""
if self.regexes_to_ignore is not None:
for pattern in self.regexes_to_ignore:
text = re.sub(pattern, "", text)
if self.ignore_case:
text = text.lower()
if self.ignore_punctuation:
text = text.translate(str.maketrans("", "", string.punctuation))
return text.strip()
def _build_choice_patterns(self, choices: list[str]) -> tuple:
"""
Build regex patterns and conversion maps for both full text
and letter-based answers.

Returns a 4-tuple: (compiled full-text regex, compiled letter regex,
choice_to_letter, letter_map). Both maps have formatted letters as values;
choice_to_letter is keyed by the filtered choice text, letter_map by the
raw letter.
"""
# NOTE(review): this span interleaves the new method with lines from the
# previous implementation (everything referring to st, punct_tbl,
# next_alpha, fallback_regexes, without_paren_* or doc["choices"]).
# For matching full text of choices
choice_patterns = []
choice_to_letter = {}
if self.ignore_case:
st = st.lower()
# For matching letter answers
letter_map = {} # Maps each raw letter to its formatted form
if self.ignore_punctuation:
# https://stackoverflow.com/a/266162
st = st.translate(punct_tbl)
return st
# Only the first self.max_choices choices receive a letter.
for i, choice in enumerate(choices):
if i >= self.max_choices:
break
filtered_resps = []
# Get the letter for this choice (A, B, C, etc)
letter = chr(ord("A") + i)
formatted_letter = self._format_letter(letter)
for r, doc in zip(resps, docs):
fallback_regexes = []
choice_to_alpha = {}
next_alpha = "A"
# Process the choice text
processed_choice = self._filter_text(choice)
without_paren_fallback_regexes = []
without_paren_to_target = {}
# Add to full text matching
# re.escape makes the choice text match literally inside the alternation.
choice_patterns.append(re.escape(processed_choice))
choice_to_letter[processed_choice] = formatted_letter
choices = doc["choices"]
for c in choices:
m = filter_ignores(c.strip())
fallback_regexes.append(f"{re.escape(m)}")
choice_to_alpha[m] = f"({next_alpha})"
# Add to letter matching
letter_map[letter] = formatted_letter
without_paren_fallback_regexes.append(next_alpha)
without_paren_to_target[next_alpha] = f"({next_alpha})"
# Create regex for full text matches
# "(?!)" is an empty negative lookahead, i.e. a pattern that never matches;
# it is used when there are no choices to search for.
full_text_pattern = "|".join(choice_patterns) if choice_patterns else "(?!)"
next_alpha = chr(ord(next_alpha) + 1)
fallback_regex = re.compile("|".join(fallback_regexes))
without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
without_paren_fallback_regex = re.compile(
rf":[\s]*({without_paren_fallback_regex})"
# Create regex for letter matches (: A, (A), etc)
# If no choices given, fall back to the first max_choices letters from "A"
if not letter_map:
letters = "".join([chr(ord("A") + i) for i in range(self.max_choices)])
else:
letters = "".join(letter_map.keys())
# Two alternatives, each with its own capture group: a parenthesised letter,
# or a bare letter (optionally preceded by answer/choice/option and a colon)
# that is followed by whitespace or end of string.
letter_pattern = rf"(?:\(([{letters}])\))|(?:(?:answer|choice|option)?:?\s*([{letters}])(?:\s|$))"
return (
re.compile(full_text_pattern),
re.compile(letter_pattern),
choice_to_letter,
letter_map,
)
def _get_choices(self, doc: dict) -> list:
"""
Safely extract choices from the document using the specified field name.
Handles nested fields using dot notation (e.g., "metadata.choices").
Returns empty list if:
- doc is None or not a dict
- field doesn't exist
- field value is None
- field value is not a list
"""
if doc is None or not isinstance(doc, dict):
return []
if "." in self.choices_field:
# Handle nested fields
fields = self.choices_field.split(".")
value = doc
for field in fields:
if not isinstance(value, dict) or field not in value:
return []
value = value[field]
if value is None:
return []
assert isinstance(value, list)
return value
else:
# Direct field access
value = doc.get(self.choices_field)
if value is None:
return []
assert isinstance(value, list)
return value
def _find_match(
self, regex, text: str, conversion_map: dict = None
) -> Union[str, None]:
"""Find regex matches and convert using the provided map if any."""
matches = regex.findall(text)
if not matches:
return None
# Handle both single matches and tuple groups
match = matches[self.group_select]
if isinstance(match, tuple):
# Take first non-empty group
match = next((m for m in match if m), None)
if match and conversion_map:
return conversion_map.get(match, match)
return match
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
# For every (responses, doc) pair, map each response to an answer letter or
# to self.fallback. Returns one inner list per document.
# NOTE(review): this span interleaves the new loop with lines from the
# previous implementation (the calls to find_match / filter_ignores and the
# fallback_regex / without_paren_* names). The self._find_match calls are
# the current code path.
filtered_resps = []
for responses, doc in zip(resps, docs):
choices = self._get_choices(doc)
# Build patterns for both full text and letter matching
full_text_re, letter_re, choice_to_letter, letter_map = (
self._build_choice_patterns(choices)
)
filtered = []
for resp in r:
match = find_match(self.regex, resp)
for resp in responses:
match = None
# Try the custom regex pattern first (if provided)
# NOTE(review): self.regex_pattern and self.regex are not assigned anywhere
# in the visible code - presumably set by the parent class; confirm.
# No conversion map is passed here, so a hit from this step is returned
# as captured, without going through _format_letter.
if self.regex_pattern != letter_re.pattern:
match = self._find_match(self.regex, resp)
if not match:
match = find_match(
fallback_regex, filter_ignores(resp), choice_to_alpha
# Try matching full text of choices
# The response is normalised with the same _filter_text used on the choices.
processed_resp = self._filter_text(resp)
match = self._find_match(
full_text_re, processed_resp, choice_to_letter
)
if not match:
match = find_match(
without_paren_fallback_regex, resp, without_paren_to_target
)
if not match:
match = self.fallback
filtered.append(match)
# Try matching letter patterns
# Upper-casing lets lower-case letters match the upper-case letter classes.
if self.ignore_case:
resp = resp.upper()
match = self._find_match(letter_re, resp, letter_map)
# A falsy match (None or empty string) is replaced by the fallback value.
filtered.append(match if match else self.fallback)
filtered_resps.append(filtered)
return filtered_resps
......@@ -21,6 +21,8 @@ filter_list:
ignore_case: true
ignore_punctuation: true
regex_pattern: "(\\([A-Z]\\))"
choices_field: "choices"
format_style: "parens"
- function: "take_first"
generation_kwargs:
until:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment