Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
3fd12675
Commit
3fd12675
authored
Jul 21, 2025
by
Baber
Browse files
type hints;
parent
46654b3d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
53 deletions
+16
-53
lm_eval/api/filter.py
lm_eval/api/filter.py
+6
-4
lm_eval/filters/extraction.py
lm_eval/filters/extraction.py
+10
-49
No files found.
lm_eval/api/filter.py
View file @
3fd12675
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Callable
,
Iterable
from
dataclasses
import
dataclass
from
typing
import
Callable
,
Iterable
,
List
,
Union
from
lm_eval.api.instance
import
Instance
...
...
@@ -20,7 +20,9 @@ class Filter(ABC):
"""
@
abstractmethod
def
apply
(
self
,
resps
:
Union
[
List
,
Iterable
],
docs
:
List
[
dict
])
->
Iterable
:
def
apply
(
self
,
resps
:
Iterable
[
list
[
str
]],
docs
:
Iterable
[
dict
]
)
->
Iterable
[
list
[
str
]]:
"""
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
...
...
@@ -40,9 +42,9 @@ class FilterEnsemble:
"""
name
:
str
filters
:
L
ist
[
Callable
[[],
Filter
]]
filters
:
l
ist
[
Callable
[[],
Filter
]]
def
apply
(
self
,
instances
:
L
ist
[
Instance
])
->
None
:
def
apply
(
self
,
instances
:
l
ist
[
Instance
])
->
None
:
resps
,
docs
=
zip
(
*
((
inst
.
resps
,
inst
.
doc
)
for
inst
in
instances
))
resps
,
docs
=
list
(
resps
),
list
(
docs
)
...
...
lm_eval/filters/extraction.py
View file @
3fd12675
import
re
import
sys
import
unicodedata
from
collections.abc
import
Iterable
from
lm_eval.api.filter
import
Filter
from
lm_eval.api.registry
import
register_filter
...
...
@@ -30,7 +31,9 @@ class RegexFilter(Filter):
self
.
group_select
=
group_select
self
.
fallback
=
fallback
def
apply
(
self
,
resps
:
list
[
list
[
str
]],
docs
:
list
[
dict
])
->
list
[
list
[
str
]]:
def
apply
(
self
,
resps
:
Iterable
[
list
[
str
]],
docs
:
Iterable
[
dict
]
)
->
Iterable
[
list
[
str
]]:
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
...
...
@@ -57,57 +60,13 @@ class RegexFilter(Filter):
return
filtered_resps
@
register_filter
(
"regex_pos"
)
class
POSFilter
(
Filter
):
""" """
def
__init__
(
self
,
regex_pattern
:
str
=
r
"\['(.*?)'\]"
,
group_select
=
0
,
fallback
=
None
,
)
->
None
:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
if
fallback
is
None
:
fallback
=
[
"invalid"
]
self
.
regex_pattern
=
regex_pattern
self
.
regex
=
re
.
compile
(
regex_pattern
)
self
.
group_select
=
group_select
self
.
fallback
=
fallback
def
apply
(
self
,
resps
,
docs
):
def
extract_tagged_tokens
(
text
):
# Extract tagged tokens list from text input using regex
tokens
=
re
.
findall
(
r
"\('([^']*)', '([^']*)'\)"
,
text
)
return
[(
token
,
pos
)
for
token
,
pos
in
tokens
]
def
extract_pos_tags
(
result
):
pos_tags
=
[]
if
isinstance
(
result
,
str
):
result
=
extract_tagged_tokens
(
result
)
pos_tags
.
extend
(
pos
for
_
,
pos
in
result
)
return
pos_tags
if
pos_tags
else
self
.
fallback
def
filter_set
(
inst
):
filtered
=
[]
for
resp
in
inst
:
match
=
extract_pos_tags
(
resp
)
filtered
.
append
(
match
)
return
filtered
filtered_resps
=
map
(
lambda
x
:
filter_set
(
x
),
resps
)
return
filtered_resps
@
register_filter
(
"remove_whitespace"
)
class
WhitespaceFilter
(
Filter
):
"""Filters out leading whitespace from responses."""
def
apply
(
self
,
resps
:
list
[
list
[
str
]],
docs
:
list
[
dict
])
->
list
[
list
[
str
]]:
def
apply
(
self
,
resps
:
Iterable
[
list
[
str
]],
docs
:
Iterable
[
dict
]
)
->
Iterable
[
list
[
str
]]:
def
filter_set
(
inst
):
filtered_resp
=
[]
for
resp
in
inst
:
...
...
@@ -152,7 +111,9 @@ class MultiChoiceRegexFilter(RegexFilter):
self
.
ignore_punctuation
=
ignore_punctuation
self
.
regexes_to_ignore
=
regexes_to_ignore
def
apply
(
self
,
resps
:
list
[
list
[
str
]],
docs
:
list
[
dict
])
->
list
[
list
[
str
]]:
def
apply
(
self
,
resps
:
Iterable
[
list
[
str
]],
docs
:
Iterable
[
dict
]
)
->
Iterable
[
list
[
str
]]:
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment