Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
81f84653
"include/ck/ck.hpp" did not exist on "6014185ac65e75f2a84cb67ef6ba83b48ae0fcb3"
Commit
81f84653
authored
Feb 11, 2025
by
Baber
Browse files
refactor multichoiceregex
parent
f8920c74
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
199 additions
and
82 deletions
+199
-82
lm_eval/filters/extraction.py
lm_eval/filters/extraction.py
+197
-82
lm_eval/tasks/gpqa/cot_zeroshot/_gpqa_cot_zeroshot_yaml
lm_eval/tasks/gpqa/cot_zeroshot/_gpqa_cot_zeroshot_yaml
+2
-0
No files found.
lm_eval/filters/extraction.py
View file @
81f84653
import
re
import
re
import
sys
import
string
import
unicodedata
from
typing
import
Union
from
typing
import
Union
from
lm_eval.api.filter
import
Filter
from
lm_eval.api.filter
import
Filter
...
@@ -106,112 +105,228 @@ class WhitespaceFilter(Filter):
...
@@ -106,112 +105,228 @@ class WhitespaceFilter(Filter):
@
register_filter
(
"multi_choice_regex"
)
@
register_filter
(
"multi_choice_regex"
)
class
MultiChoiceRegexFilter
(
RegexFilter
):
class
MultiChoiceRegexFilter
(
RegexFilter
):
"""
"""A filter for extracting multiple choice answers from text responses.
A filter used to extract a model's answer on multiple choice questions with
letter answers. assumes each document has a "choices" field
This filter processes responses in the following order:
containing the list of answer choices and that the answer label symbols
1. Full text matches of answer choices (e.g., "The earth is round" -> "A")
are of the form (A), (B), (C), ... or A, B, C.
2. Letter-based answers in various formats (e.g., "(A)", "A:", "Answer: A")
Args:
regex_pattern (str, optional): Custom regex pattern for matching. If None, uses default pattern.
group_select (int, default=0): Which regex group to select from matches.
fallback (str, default="[invalid]"): Value to return when no match is found.
ignore_case (bool, default=True): Whether to ignore case when matching.
ignore_punctuation (bool, default=False): Whether to ignore punctuation when matching.
regexes_to_ignore (list, optional): List of regex patterns to remove from text before matching.
max_choices (int, default=4): Maximum number of choices to consider (A-D).
choices_field (str, default="choices"): Field name or dot path to get choices from document.
format_style (str, default="plain"): Output format style ("plain" for "A", "parens" for "(A)").
Examples:
>>> filter = MultiChoiceRegexFilter(format_style="parens", choices_field="choices")
>>> doc = {"choices": ["The earth is round", "The earth is flat"]}
>>> responses = ["The earth is round", "Answer: B", "(A)"]
>>> filter.apply([responses], [doc])
[[["(A)", "(B)", "(A)"]]]
# With nested choices
>>> doc = {"metadata": {"question": {"choices": ["True", "False"]}}}
>>> filter = MultiChoiceRegexFilter(choices_field="metadata.question.choices")
# With custom format
>>> filter = MultiChoiceRegexFilter(format_style="plain") # Returns "A" instead of "(A)"
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
regex_pattern
:
str
=
r
"#### (\-?[0-9\.\,]+)"
,
regex_pattern
:
str
=
None
,
group_select
=
0
,
group_select
:
int
=
0
,
fallback
:
str
=
"[invalid]"
,
fallback
:
str
=
"[invalid]"
,
ignore_case
=
False
,
ignore_case
:
bool
=
True
,
ignore_punctuation
=
False
,
ignore_punctuation
:
bool
=
False
,
regexes_to_ignore
=
None
,
regexes_to_ignore
:
list
=
None
,
max_choices
:
int
=
4
,
# A-Z
choices_field
:
str
=
"choices"
,
format_style
:
str
=
"plain"
,
)
->
None
:
)
->
None
:
"""
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
- step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
- step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
group_select: Selects the (group_select)th match from the findall result.
ignore_case: Ignores the case during step 1 matching
ignore_punctuation: Remove the punctuation during step 1 matching
regexes_to_ignore: Remove these regexes during step 1 matching
"""
super
().
__init__
(
regex_pattern
,
group_select
,
fallback
)
self
.
ignore_case
=
ignore_case
self
.
ignore_case
=
ignore_case
self
.
ignore_punctuation
=
ignore_punctuation
self
.
ignore_punctuation
=
ignore_punctuation
self
.
regexes_to_ignore
=
regexes_to_ignore
self
.
regexes_to_ignore
=
regexes_to_ignore
self
.
max_choices
=
max_choices
self
.
choices_field
=
choices_field
self
.
format_style
=
format_style
def
apply
(
self
,
resps
:
list
[
list
[
str
]],
docs
:
list
[
dict
])
->
list
[
list
[
str
]]:
# If no custom pattern, create comprehensive letter pattern
# here, we assume we have a list, in which each element is
if
regex_pattern
is
None
:
# a list of model responses for some particular input/target pair.
letters
=
""
.
join
([
chr
(
ord
(
"A"
)
+
i
)
for
i
in
range
(
max_choices
)])
# so we process each of t
hes
e
(
same input/target response sets)
# Matc
hes (
A), A:, Answer: A etc.
# independently (and keep them a list.)
regex_pattern
=
rf
"(?:\(([
{
letters
}
])\))|(?:(?:answer|choice|option)?:?\s*([
{
letters
}
])(?:\s|$))"
def
find_match
(
regex
,
resp
,
convert_dict
=
{}):
super
().
__init__
(
regex_pattern
,
group_select
,
fallback
)
match
=
regex
.
findall
(
resp
)
if
match
:
match
=
match
[
self
.
group_select
]
if
isinstance
(
match
,
tuple
):
match
=
[
m
for
m
in
match
if
m
][
0
]
match
=
match
.
strip
()
if
match
and
match
in
convert_dict
:
match
=
convert_dict
[
match
]
return
match
punct_tbl
=
dict
.
fromkeys
(
i
for
i
in
range
(
sys
.
maxunicode
)
if
unicodedata
.
category
(
chr
(
i
)).
startswith
(
"P"
)
)
def
filter_ignores
(
st
):
def
_format_letter
(
self
,
letter
:
str
)
->
str
:
if
self
.
regexes_to_ignore
is
not
None
:
"""Format a letter based on format_style setting"""
for
s
in
self
.
regexes_to_ignore
:
if
self
.
format_style
==
"parens"
:
st
=
re
.
sub
(
s
,
""
,
st
)
return
f
"(
{
letter
}
)"
elif
self
.
format_style
==
"plain"
:
return
letter
# Add more format styles here as needed:
# elif self.format_style == "brackets":
# return f"[{letter}]"
# elif self.format_style == "numbered":
# return f"{ord(letter) - ord('A') + 1}"
else
:
return
f
"(
{
letter
}
)"
def
_filter_text
(
self
,
text
:
str
)
->
str
:
"""Apply text filtering rules (case, punctuation, regex ignores)"""
if
self
.
regexes_to_ignore
is
not
None
:
for
pattern
in
self
.
regexes_to_ignore
:
text
=
re
.
sub
(
pattern
,
""
,
text
)
if
self
.
ignore_case
:
text
=
text
.
lower
()
if
self
.
ignore_punctuation
:
text
=
text
.
translate
(
str
.
maketrans
(
""
,
""
,
string
.
punctuation
))
return
text
.
strip
()
def
_build_choice_patterns
(
self
,
choices
:
list
[
str
])
->
tuple
:
"""
Build regex patterns and conversion maps for both full text
and letter-based answers.
"""
# For matching full text of choices
choice_patterns
=
[]
choice_to_letter
=
{}
if
self
.
ignore_case
:
# For matching letter answers
st
=
st
.
lower
()
letter_map
=
{}
# Maps raw letters to (A) format
if
self
.
ignore_punctuation
:
for
i
,
choice
in
enumerate
(
choices
):
# https://stackoverflow.com/a/266162
if
i
>=
self
.
max_choices
:
st
=
st
.
translate
(
punct_tbl
)
break
return
st
filtered_resps
=
[]
# Get the letter for this choice (A, B, C, etc)
letter
=
chr
(
ord
(
"A"
)
+
i
)
formatted_letter
=
self
.
_format_letter
(
letter
)
for
r
,
doc
in
zip
(
resps
,
docs
):
# Process the choice text
fallback_regexes
=
[]
processed_choice
=
self
.
_filter_text
(
choice
)
choice_to_alpha
=
{}
next_alpha
=
"A"
without_paren_fallback_regexes
=
[]
# Add to full text matching
without_paren_to_target
=
{}
choice_patterns
.
append
(
re
.
escape
(
processed_choice
))
choice_to_letter
[
processed_choice
]
=
formatted_letter
choices
=
doc
[
"choices"
]
# Add to letter matching
for
c
in
choices
:
letter_map
[
letter
]
=
formatted_letter
m
=
filter_ignores
(
c
.
strip
())
fallback_regexes
.
append
(
f
"
{
re
.
escape
(
m
)
}
"
)
choice_to_alpha
[
m
]
=
f
"(
{
next_alpha
}
)"
without_paren_fallback_regexes
.
append
(
next_alpha
)
# Create regex for full text matches
without_paren_to_target
[
next_alpha
]
=
f
"(
{
next_alpha
}
)"
full_text_pattern
=
"|"
.
join
(
choice_patterns
)
if
choice_patterns
else
"(?!
)"
next_alpha
=
chr
(
ord
(
next_alpha
)
+
1
)
# Create regex for letter matches (: A, (A), etc)
fallback_regex
=
re
.
compile
(
"|"
.
join
(
fallback_regexes
))
# If no choices given, use default A-Z range based on max_choices
without_paren_fallback_regex
=
"|"
.
join
(
without_paren_fallback_regexes
)
if
not
letter_map
:
without_paren_fallback_regex
=
re
.
compile
(
letters
=
""
.
join
([
chr
(
ord
(
"A"
)
+
i
)
for
i
in
range
(
self
.
max_choices
)])
rf
":[\s]*(
{
without_paren_fallback_regex
}
)"
else
:
letters
=
""
.
join
(
letter_map
.
keys
())
letter_pattern
=
rf
"(?:\(([
{
letters
}
])\))|(?:(?:answer|choice|option)?:?\s*([
{
letters
}
])(?:\s|$))"
return
(
re
.
compile
(
full_text_pattern
),
re
.
compile
(
letter_pattern
),
choice_to_letter
,
letter_map
,
)
def
_get_choices
(
self
,
doc
:
dict
)
->
list
:
"""
Safely extract choices from the document using the specified field name.
Handles nested fields using dot notation (e.g., "metadata.choices").
Returns empty list if:
- doc is None or not a dict
- field doesn't exist
- field value is None
- field value is not a list
"""
if
doc
is
None
or
not
isinstance
(
doc
,
dict
):
return
[]
if
"."
in
self
.
choices_field
:
# Handle nested fields
fields
=
self
.
choices_field
.
split
(
"."
)
value
=
doc
for
field
in
fields
:
if
not
isinstance
(
value
,
dict
)
or
field
not
in
value
:
return
[]
value
=
value
[
field
]
if
value
is
None
:
return
[]
assert
isinstance
(
value
,
list
)
return
value
else
:
# Direct field access
value
=
doc
.
get
(
self
.
choices_field
)
if
value
is
None
:
return
[]
assert
isinstance
(
value
,
list
)
return
value
def
_find_match
(
self
,
regex
,
text
:
str
,
conversion_map
:
dict
=
None
)
->
Union
[
str
,
None
]:
"""Find regex matches and convert using the provided map if any."""
matches
=
regex
.
findall
(
text
)
if
not
matches
:
return
None
# Handle both single matches and tuple groups
match
=
matches
[
self
.
group_select
]
if
isinstance
(
match
,
tuple
):
# Take first non-empty group
match
=
next
((
m
for
m
in
match
if
m
),
None
)
if
match
and
conversion_map
:
return
conversion_map
.
get
(
match
,
match
)
return
match
def
apply
(
self
,
resps
:
list
[
list
[
str
]],
docs
:
list
[
dict
])
->
list
[
list
[
str
]]:
filtered_resps
=
[]
for
responses
,
doc
in
zip
(
resps
,
docs
):
choices
=
self
.
_get_choices
(
doc
)
# Build patterns for both full text and letter matching
full_text_re
,
letter_re
,
choice_to_letter
,
letter_map
=
(
self
.
_build_choice_patterns
(
choices
)
)
)
filtered
=
[]
filtered
=
[]
for
resp
in
r
:
for
resp
in
responses
:
match
=
find_match
(
self
.
regex
,
resp
)
match
=
None
# Try the custom regex pattern first (if provided)
if
self
.
regex_pattern
!=
letter_re
.
pattern
:
match
=
self
.
_find_match
(
self
.
regex
,
resp
)
if
not
match
:
if
not
match
:
match
=
find_match
(
# Try matching full text of choices
fallback_regex
,
filter_ignores
(
resp
),
choice_to_alpha
processed_resp
=
self
.
_filter_text
(
resp
)
match
=
self
.
_find_match
(
full_text_re
,
processed_resp
,
choice_to_letter
)
)
if
not
match
:
if
not
match
:
match
=
find_match
(
# Try matching letter patterns
without_paren_fallback_regex
,
resp
,
without_paren_to_target
if
self
.
ignore_case
:
)
resp
=
resp
.
upper
()
if
not
match
:
match
=
self
.
_find_match
(
letter_re
,
resp
,
letter_map
)
match
=
self
.
fallback
filtered
.
append
(
match
)
filtered
.
append
(
match
if
match
else
self
.
fallback
)
filtered_resps
.
append
(
filtered
)
filtered_resps
.
append
(
filtered
)
return
filtered_resps
return
filtered_resps
lm_eval/tasks/gpqa/cot_zeroshot/_gpqa_cot_zeroshot_yaml
View file @
81f84653
...
@@ -21,6 +21,8 @@ filter_list:
...
@@ -21,6 +21,8 @@ filter_list:
ignore_case: true
ignore_case: true
ignore_punctuation: true
ignore_punctuation: true
regex_pattern: "(\\([A-Z]\\))"
regex_pattern: "(\\([A-Z]\\))"
choices_field: "choices"
format_style: "parens"
- function: "take_first"
- function: "take_first"
generation_kwargs:
generation_kwargs:
until:
until:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment