Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
81f84653
Commit
81f84653
authored
Feb 11, 2025
by
Baber
Browse files
refactor multichoiceregex
parent
f8920c74
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
199 additions
and
82 deletions
+199
-82
lm_eval/filters/extraction.py
lm_eval/filters/extraction.py
+197
-82
lm_eval/tasks/gpqa/cot_zeroshot/_gpqa_cot_zeroshot_yaml
lm_eval/tasks/gpqa/cot_zeroshot/_gpqa_cot_zeroshot_yaml
+2
-0
No files found.
lm_eval/filters/extraction.py
View file @
81f84653
import
re
import
sys
import
unicodedata
import
string
from
typing
import
Union
from
lm_eval.api.filter
import
Filter
...
...
@@ -106,112 +105,228 @@ class WhitespaceFilter(Filter):
@
register_filter
(
"multi_choice_regex"
)
class
MultiChoiceRegexFilter
(
RegexFilter
):
"""
A filter used to extract a model's answer on multiple choice questions with
letter answers. assumes each document has a "choices" field
containing the list of answer choices and that the answer label symbols
are of the form (A), (B), (C), ... or A, B, C.
"""A filter for extracting multiple choice answers from text responses.
This filter processes responses in the following order:
1. Full text matches of answer choices (e.g., "The earth is round" -> "A")
2. Letter-based answers in various formats (e.g., "(A)", "A:", "Answer: A")
Args:
regex_pattern (str, optional): Custom regex pattern for matching. If None, uses default pattern.
group_select (int, default=0): Which regex group to select from matches.
fallback (str, default="[invalid]"): Value to return when no match is found.
ignore_case (bool, default=True): Whether to ignore case when matching.
ignore_punctuation (bool, default=False): Whether to ignore punctuation when matching.
regexes_to_ignore (list, optional): List of regex patterns to remove from text before matching.
max_choices (int, default=4): Maximum number of choices to consider (A-D).
choices_field (str, default="choices"): Field name or dot path to get choices from document.
format_style (str, default="plain"): Output format style ("plain" for "A", "parens" for "(A)").
Examples:
>>> filter = MultiChoiceRegexFilter(format_style="parens", choices_field="choices")
>>> doc = {"choices": ["The earth is round", "The earth is flat"]}
>>> responses = ["The earth is round", "Answer: B", "(A)"]
>>> filter.apply([responses], [doc])
[["(A)", "(B)", "(A)"]]
# With nested choices
>>> doc = {"metadata": {"question": {"choices": ["True", "False"]}}}
>>> filter = MultiChoiceRegexFilter(choices_field="metadata.question.choices")
# With custom format
>>> filter = MultiChoiceRegexFilter(format_style="plain") # Returns "A" instead of "(A)"
"""
def
__init__
(
self
,
regex_pattern
:
str
=
r
"#### (\-?[0-9\.\,]+)"
,
group_select
=
0
,
regex_pattern
:
str
=
None
,
group_select
:
int
=
0
,
fallback
:
str
=
"[invalid]"
,
ignore_case
=
False
,
ignore_punctuation
=
False
,
regexes_to_ignore
=
None
,
ignore_case
:
bool
=
True
,
ignore_punctuation
:
bool
=
False
,
regexes_to_ignore
:
list
=
None
,
max_choices
:
int
=
4
,
# default 4 -> letters A-D
choices_field
:
str
=
"choices"
,
format_style
:
str
=
"plain"
,
)
->
None
:
"""
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
- step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
- step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
group_select: Selects the (group_select)th match from the findall result.
ignore_case: Ignores the case during step 1 matching
ignore_punctuation: Remove the punctuation during step 1 matching
regexes_to_ignore: Remove these regexes during step 1 matching
"""
super
().
__init__
(
regex_pattern
,
group_select
,
fallback
)
self
.
ignore_case
=
ignore_case
self
.
ignore_punctuation
=
ignore_punctuation
self
.
regexes_to_ignore
=
regexes_to_ignore
self
.
max_choices
=
max_choices
self
.
choices_field
=
choices_field
self
.
format_style
=
format_style
def
apply
(
self
,
resps
:
list
[
list
[
str
]],
docs
:
list
[
dict
])
->
list
[
list
[
str
]]:
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
# independently (and keep them a list.)
# If no custom pattern, create comprehensive letter pattern
if
regex_pattern
is
None
:
letters
=
""
.
join
([
chr
(
ord
(
"A"
)
+
i
)
for
i
in
range
(
max_choices
)])
# Matches (A), A:, Answer: A etc.
regex_pattern
=
rf
"(?:\(([
{
letters
}
])\))|(?:(?:answer|choice|option)?:?\s*([
{
letters
}
])(?:\s|$))"
def
find_match
(
regex
,
resp
,
convert_dict
=
{}):
match
=
regex
.
findall
(
resp
)
if
match
:
match
=
match
[
self
.
group_select
]
if
isinstance
(
match
,
tuple
):
match
=
[
m
for
m
in
match
if
m
][
0
]
match
=
match
.
strip
()
if
match
and
match
in
convert_dict
:
match
=
convert_dict
[
match
]
return
match
punct_tbl
=
dict
.
fromkeys
(
i
for
i
in
range
(
sys
.
maxunicode
)
if
unicodedata
.
category
(
chr
(
i
)).
startswith
(
"P"
)
)
super
().
__init__
(
regex_pattern
,
group_select
,
fallback
)
def
filter_ignores
(
st
):
if
self
.
regexes_to_ignore
is
not
None
:
for
s
in
self
.
regexes_to_ignore
:
st
=
re
.
sub
(
s
,
""
,
st
)
def
_format_letter
(
self
,
letter
:
str
)
->
str
:
"""Format a letter based on format_style setting"""
if
self
.
format_style
==
"parens"
:
return
f
"(
{
letter
}
)"
elif
self
.
format_style
==
"plain"
:
return
letter
# Add more format styles here as needed:
# elif self.format_style == "brackets":
# return f"[{letter}]"
# elif self.format_style == "numbered":
# return f"{ord(letter) - ord('A') + 1}"
else
:
return
f
"(
{
letter
}
)"
def
_filter_text
(
self
,
text
:
str
)
->
str
:
"""Apply text filtering rules (case, punctuation, regex ignores)"""
if
self
.
regexes_to_ignore
is
not
None
:
for
pattern
in
self
.
regexes_to_ignore
:
text
=
re
.
sub
(
pattern
,
""
,
text
)
if
self
.
ignore_case
:
text
=
text
.
lower
()
if
self
.
ignore_punctuation
:
text
=
text
.
translate
(
str
.
maketrans
(
""
,
""
,
string
.
punctuation
))
return
text
.
strip
()
def
_build_choice_patterns
(
self
,
choices
:
list
[
str
])
->
tuple
:
"""
Build regex patterns and conversion maps for both full text
and letter-based answers.
"""
# For matching full text of choices
choice_patterns
=
[]
choice_to_letter
=
{}
if
self
.
ignore_case
:
st
=
st
.
lower
()
# For matching letter answers
letter_map
=
{}
# Maps raw letters to (A) format
if
self
.
ignore_punctuation
:
# https://stackoverflow.com/a/266162
st
=
st
.
translate
(
punct_tbl
)
return
st
for
i
,
choice
in
enumerate
(
choices
):
if
i
>=
self
.
max_choices
:
break
filtered_resps
=
[]
# Get the letter for this choice (A, B, C, etc)
letter
=
chr
(
ord
(
"A"
)
+
i
)
formatted_letter
=
self
.
_format_letter
(
letter
)
for
r
,
doc
in
zip
(
resps
,
docs
):
fallback_regexes
=
[]
choice_to_alpha
=
{}
next_alpha
=
"A"
# Process the choice text
processed_choice
=
self
.
_filter_text
(
choice
)
without_paren_fallback_regexes
=
[]
without_paren_to_target
=
{}
# Add to full text matching
choice_patterns
.
append
(
re
.
escape
(
processed_choice
))
choice_to_letter
[
processed_choice
]
=
formatted_letter
choices
=
doc
[
"choices"
]
for
c
in
choices
:
m
=
filter_ignores
(
c
.
strip
())
fallback_regexes
.
append
(
f
"
{
re
.
escape
(
m
)
}
"
)
choice_to_alpha
[
m
]
=
f
"(
{
next_alpha
}
)"
# Add to letter matching
letter_map
[
letter
]
=
formatted_letter
without_paren_fallback_regexes
.
append
(
next_alpha
)
without_paren_to_target
[
next_alpha
]
=
f
"(
{
next_alpha
}
)"
# Create regex for full text matches
full_text_pattern
=
"|"
.
join
(
choice_patterns
)
if
choice_patterns
else
"(?!
)"
next_alpha
=
chr
(
ord
(
next_alpha
)
+
1
)
fallback_regex
=
re
.
compile
(
"|"
.
join
(
fallback_regexes
))
without_paren_fallback_regex
=
"|"
.
join
(
without_paren_fallback_regexes
)
without_paren_fallback_regex
=
re
.
compile
(
rf
":[\s]*(
{
without_paren_fallback_regex
}
)"
# Create regex for letter matches (: A, (A), etc)
# If no choices given, use default A-Z range based on max_choices
if
not
letter_map
:
letters
=
""
.
join
([
chr
(
ord
(
"A"
)
+
i
)
for
i
in
range
(
self
.
max_choices
)])
else
:
letters
=
""
.
join
(
letter_map
.
keys
())
letter_pattern
=
rf
"(?:\(([
{
letters
}
])\))|(?:(?:answer|choice|option)?:?\s*([
{
letters
}
])(?:\s|$))"
return
(
re
.
compile
(
full_text_pattern
),
re
.
compile
(
letter_pattern
),
choice_to_letter
,
letter_map
,
)
def
_get_choices
(
self
,
doc
:
dict
)
->
list
:
"""
Safely extract choices from the document using the specified field name.
Handles nested fields using dot notation (e.g., "metadata.choices").
Returns empty list if:
- doc is None or not a dict
- field doesn't exist
- field value is None
- field value is not a list
"""
if
doc
is
None
or
not
isinstance
(
doc
,
dict
):
return
[]
if
"."
in
self
.
choices_field
:
# Handle nested fields
fields
=
self
.
choices_field
.
split
(
"."
)
value
=
doc
for
field
in
fields
:
if
not
isinstance
(
value
,
dict
)
or
field
not
in
value
:
return
[]
value
=
value
[
field
]
if
value
is
None
:
return
[]
assert
isinstance
(
value
,
list
)
return
value
else
:
# Direct field access
value
=
doc
.
get
(
self
.
choices_field
)
if
value
is
None
:
return
[]
assert
isinstance
(
value
,
list
)
return
value
def
_find_match
(
self
,
regex
,
text
:
str
,
conversion_map
:
dict
=
None
)
->
Union
[
str
,
None
]:
"""Find regex matches and convert using the provided map if any."""
matches
=
regex
.
findall
(
text
)
if
not
matches
:
return
None
# Handle both single matches and tuple groups
match
=
matches
[
self
.
group_select
]
if
isinstance
(
match
,
tuple
):
# Take first non-empty group
match
=
next
((
m
for
m
in
match
if
m
),
None
)
if
match
and
conversion_map
:
return
conversion_map
.
get
(
match
,
match
)
return
match
def
apply
(
self
,
resps
:
list
[
list
[
str
]],
docs
:
list
[
dict
])
->
list
[
list
[
str
]]:
filtered_resps
=
[]
for
responses
,
doc
in
zip
(
resps
,
docs
):
choices
=
self
.
_get_choices
(
doc
)
# Build patterns for both full text and letter matching
full_text_re
,
letter_re
,
choice_to_letter
,
letter_map
=
(
self
.
_build_choice_patterns
(
choices
)
)
filtered
=
[]
for
resp
in
r
:
match
=
find_match
(
self
.
regex
,
resp
)
for
resp
in
responses
:
match
=
None
# Try the custom regex pattern first (if provided)
if
self
.
regex_pattern
!=
letter_re
.
pattern
:
match
=
self
.
_find_match
(
self
.
regex
,
resp
)
if
not
match
:
match
=
find_match
(
fallback_regex
,
filter_ignores
(
resp
),
choice_to_alpha
# Try matching full text of choices
processed_resp
=
self
.
_filter_text
(
resp
)
match
=
self
.
_find_match
(
full_text_re
,
processed_resp
,
choice_to_letter
)
if
not
match
:
match
=
find_match
(
without_paren_fallback_regex
,
resp
,
without_paren_to_target
)
if
not
match
:
match
=
self
.
fallback
filtered
.
append
(
match
)
# Try matching letter patterns
if
self
.
ignore_case
:
resp
=
resp
.
upper
()
match
=
self
.
_find_match
(
letter_re
,
resp
,
letter_map
)
filtered
.
append
(
match
if
match
else
self
.
fallback
)
filtered_resps
.
append
(
filtered
)
return
filtered_resps
lm_eval/tasks/gpqa/cot_zeroshot/_gpqa_cot_zeroshot_yaml
View file @
81f84653
...
...
@@ -21,6 +21,8 @@ filter_list:
ignore_case: true
ignore_punctuation: true
regex_pattern: "(\\([A-Z]\\))"
choices_field: "choices"
format_style: "parens"
- function: "take_first"
generation_kwargs:
until:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment