Commit 312374bc (gaoqiong/lm-evaluation-harness)
authored Jul 07, 2025 by Baber
parent 90cf3b89

type hints
Showing 8 changed files with 70 additions and 70 deletions.
lm_eval/api/filter.py              +7   -4
lm_eval/api/task.py                +43  -33
lm_eval/evaluator.py               +3   -4
lm_eval/filters/custom.py          +0   -1
lm_eval/filters/extraction.py      +4   -0
lm_eval/filters/selection.py       +6   -15
lm_eval/filters/transformation.py  +2   -10
lm_eval/utils.py                   +5   -3

lm_eval/api/filter.py  (view file @ 312374bc)

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Callable, Iterable, List, Union
+from typing import Callable, Iterable, TypeVar

 from lm_eval.api.instance import Instance

+T = TypeVar("T")
+

 class Filter(ABC):
     """
     Filter classes operate on a per-task level.
     ...

@@ -20,7 +23,7 @@ class Filter(ABC):
     """

     @abstractmethod
-    def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
+    def apply(self, resps: Iterable[list[T]], docs: list[dict]) -> Iterable[list[T]]:
         """
         Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
         Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
         ...

@@ -40,9 +43,9 @@ class FilterEnsemble:
     """

     name: str
-    filters: List[Callable[[], Filter]]
+    filters: list[Callable[[], Filter]]

-    def apply(self, instances: List[Instance]) -> None:
+    def apply(self, instances: list[Instance]) -> None:
         resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
         # TODO: add backward
         # unwrap responses from GenerateOutput as the filters expect strings
 ...
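For orientation, a minimal sketch (not part of this commit) of what a concrete filter looks like against the new generic `apply` signature; the `StripFilter` name and behaviour are hypothetical:

    from typing import Iterable

    from lm_eval.api.filter import Filter


    class StripFilter(Filter):
        # Hypothetical filter: strip surrounding whitespace from each string response.
        def apply(
            self, resps: Iterable[list[str]], docs: list[dict]
        ) -> Iterable[list[str]]:
            # One response list per document goes in; a filtered list per document
            # comes back out, in the same order as the input.
            return ([r.strip() for r in doc_resps] for doc_resps in resps)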

lm_eval/api/task.py  (view file @ 312374bc)

 import abc
 import ast
+import itertools
 import logging
 import random
 import re
 ...
@@ -109,24 +110,19 @@ class TaskConfig(dict):
             )

     def __post_init__(self) -> None:
-        if self.generation_kwargs is not None:
-            if self.output_type != "generate_until":
-                eval_logger.warning(
-                    f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
-                )
-
-            if "temperature" in self.generation_kwargs:
-                self.generation_kwargs["temperature"] = float(
-                    self.generation_kwargs["temperature"]
-                )
+        if self.output_type == "generate_until":
+            if self.generation_kwargs is not None:
+                if "temperature" in self.generation_kwargs:
+                    self.generation_kwargs["temperature"] = float(
+                        self.generation_kwargs["temperature"]
+                    )

-            if "until" not in self.generation_kwargs:
-                eval_logger.warning(
-                    f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
-                )
-                self.generation_kwargs["until"] = [self.fewshot_delimiter]
-        else:
-            if self.output_type == "generate_until":
+                if "until" not in self.generation_kwargs:
+                    eval_logger.warning(
+                        f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
+                    )
+                    self.generation_kwargs["until"] = [self.fewshot_delimiter]
+            else:
                 # ensure that we greedily generate in absence of explicit arguments otherwise
                 self.generation_kwargs = {
                     "until": (
 ...
@@ -140,6 +136,11 @@ class TaskConfig(dict):
                 eval_logger.warning(
                     f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
                 )
+        else:
+            if self.generation_kwargs is not None:
+                eval_logger.warning(
+                    f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
+                )

     def __getitem__(self, item):
         return getattr(self, item)
 ...
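To make the reordered branches above easier to follow, here is a hedged, standalone sketch of the behaviour they implement for `generate_until` tasks (an illustrative function, not the actual `TaskConfig` class; the greedy-default dict is truncated in the hunk, so only the `until` key is shown):

    def normalize_generation_kwargs(output_type, generation_kwargs, fewshot_delimiter="\n\n"):
        # Sketch of the new __post_init__ flow: branch on output_type first.
        if output_type == "generate_until":
            if generation_kwargs is not None:
                if "temperature" in generation_kwargs:
                    generation_kwargs["temperature"] = float(generation_kwargs["temperature"])
                if "until" not in generation_kwargs:
                    # fall back to the few-shot delimiter as the stop sequence
                    generation_kwargs["until"] = [fewshot_delimiter]
            else:
                # defaults when no generation_kwargs are given (full default dict truncated above)
                generation_kwargs = {"until": [fewshot_delimiter]}
        elif generation_kwargs is not None:
            # non-generative task: generation_kwargs are only warned about, not used
            pass
        return generation_kwargs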
@@ -1558,7 +1559,7 @@ class ConfigurableTask(Task):
             **kwargs,
         )

-    def process_results(self, doc, results):
+    def process_results(self, doc, results) -> dict:
         if callable(self.config.process_results):
             return self.config.process_results(doc, results)
 ...
@@ -1779,11 +1780,11 @@ class ConfigurableTask(Task):

     def compute_sample_metrics(
         self,
-        requests: list[Instance] = None,
-        filter_keys: list[str] = None,
-        indices: list[int] = None,
+        requests: Optional[list[Instance]] = None,
+        filter_keys: Optional[list[str]] = None,
+        indices: Optional[list[int]] = None,
         rank: int = 1,
-        limit: int = None,
+        limit: Optional[int] = None,
         world_size: int = 1,
         log_samples: bool = False,
     ) -> tuple[
 ...
@@ -1807,6 +1808,9 @@ class ConfigurableTask(Task):
         else:
             requests = requests if requests else self.instances

+        all_metrics = defaultdict(list)
+        samples = [] if log_samples else None
+
         ### Collect values of metrics on all datapoints ###
         # Pre-process task.instances to group by doc_id
         instances_by_doc_id = defaultdict(list)
 ...
@@ -1815,8 +1819,6 @@
         # Sort instances within each group
         for instances in instances_by_doc_id.values():
             instances.sort(key=lambda x: x.idx)

-        _all_metrics = defaultdict(list)
-        _samples = [] if log_samples else None
         if filter_keys is None:
             filter_keys = (
 ...
@@ -1840,9 +1842,16 @@ class ConfigurableTask(Task):
                 requests = instances_by_doc_id[_doc_id_true]
                 if self.OUTPUT_TYPE != "generate_until":
                     # if one doc has multiple instances then calculate metric together
-                    metrics = self.process_results(
-                        doc, [req.filtered_resps[filter_key] for req in requests]
-                    )
+                    metrics = [
+                        self.process_results(
+                            doc,
+                            list(
+                                itertools.chain.from_iterable(
+                                    [req.filtered_resps[filter_key] for req in requests]
+                                )
+                            ),
+                        )
+                    ]
                 else:
                     metrics = [
                         self.process_results(doc, response)
 ...
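The switch to `itertools.chain.from_iterable` above flattens the per-request filtered responses into one list before `process_results` is called once for the document. A small, self-contained illustration of that flattening (toy values, not from the commit):

    import itertools

    # Each request contributes a list of filtered responses for the same document.
    filtered_per_request = [["A"], ["B", "C"]]

    # The new code flattens them before handing them to process_results.
    flattened = list(itertools.chain.from_iterable(filtered_per_request))
    assert flattened == ["A", "B", "C"]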
@@ -1857,20 +1866,21 @@
                     for k, v in metric.items():
                         _sample_metric[k].append(v)

                 if log_samples:
-                    _samples.append(
+                    samples.append(
                         create_sample_log(
                             doc=doc,
                             doc_id=_doc_id_true,
                             target=self.doc_to_target(doc),
-                            requests=requests,
-                            metric_names=metrics,
+                            requests=tuple(requests),
+                            metric_names=tuple(str(x) for x in metrics[0]),
                             filter_key=filter_key,
+                            metrics=tuple(metrics),
                         )
                     )
             for metric_name, _score in _sample_metric.items():
-                _all_metrics[(metric_name, filter_key)].append(_score)
-        self.metric_results = _all_metrics
-        return _all_metrics, _samples
+                all_metrics[(metric_name, filter_key)].append(_score)
+        self.metric_results = all_metrics
+        return all_metrics, samples

     def compute_agg_metrics(
         self,
 ...
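After the rename, the accumulator returned by `compute_sample_metrics` keys each list of per-document scores by a `(metric_name, filter_key)` pair. A toy sketch of that shape (metric and filter names are illustrative, not from the commit):

    from collections import defaultdict

    # Mirrors `all_metrics[(metric_name, filter_key)].append(_score)` above.
    all_metrics = defaultdict(list)
    all_metrics[("exact_match", "take_first")].append([1.0])
    all_metrics[("exact_match", "take_first")].append([0.0])

    # One entry per (metric, filter) pair; one appended score list per document.
    print(dict(all_metrics))  # {('exact_match', 'take_first'): [[1.0], [0.0]]}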

lm_eval/evaluator.py  (view file @ 312374bc)

 ...
@@ -352,8 +352,6 @@ def simple_evaluate(
         verbosity=verbosity,
         confirm_run_unsafe_code=confirm_run_unsafe_code,
     )

-    if verbosity is not None:
-        setup_logging(verbosity=verbosity)
     if lm.rank == 0:
         if isinstance(model, str):
 ...
@@ -588,14 +586,13 @@ def evaluate(
             ### Collect values of metrics on all datapoints ###
             # # unpack results and sort back in order and return control to Task
             # TODO: make it possible to use a different metric per filter
-            _metrics, samples = task.compute_sample_metrics(
-                task_output.sample
+            _metrics, samples = task.compute_sample_metrics(
                 indices=samples,
                 rank=RANK,
                 limit=limit,
                 world_size=WORLD_SIZE,
                 log_samples=log_samples,
             )
             task_output.sample_metrics = _metrics
             if log_samples:
                 task_output.logged_samples = samples
 ...
@@ -606,6 +603,7 @@ def evaluate(
         if log_samples:
             # for task_name, task_samples in list(samples.items()):
             full_samples = [None] * WORLD_SIZE if RANK == 0 else None
+            eval_logger.info(task_output.logged_samples)
             torch.distributed.gather_object(
                 obj=task_output.logged_samples,
                 object_gather_list=full_samples,
 ...
@@ -620,6 +618,7 @@
             # then collect metrics across all ranks
             for metrics in task_output.sample_metrics:
                 metric_list = [None] * WORLD_SIZE if RANK == 0 else None
+                eval_logger.info(task_output.sample_metrics[metrics])
                 torch.distributed.gather_object(
                     obj=task_output.sample_metrics[metrics],
                     object_gather_list=metric_list,
 ...
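The two added `eval_logger.info` calls log what each rank is about to contribute before `torch.distributed.gather_object` collects it on rank 0. A hedged sketch of that gathering pattern (assumes an already-initialized process group; variable names and values are illustrative):

    import torch.distributed as dist

    # Rank 0 allocates one slot per rank; the other ranks pass None.
    metric_list = [None] * dist.get_world_size() if dist.get_rank() == 0 else None

    local_scores = [0.5, 1.0]  # this rank's per-sample scores (toy values)
    dist.gather_object(obj=local_scores, object_gather_list=metric_list)

    if dist.get_rank() == 0:
        # metric_list now holds every rank's list, ordered by rank.
        merged = [s for rank_scores in metric_list for s in rank_scores]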

lm_eval/filters/custom.py  (view file @ 312374bc)

 ...
@@ -10,7 +10,6 @@ class CustomFilter(Filter):
     def __init__(self, **kwargs) -> None:
         self.filter_fn = kwargs.pop("filter_fn")

         super().__init__(**kwargs)

     def apply(self, resps, docs):
 ...

lm_eval/filters/extraction.py  (view file @ 312374bc)

 ...
@@ -20,11 +20,13 @@ class RegexFilter(Filter):
         regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
         group_select: int = 0,
         fallback: str = "[invalid]",
+        **kwargs,
     ) -> None:
         """
         pass a string `regex` to run `re.compile(r"regex")` on.
         `fallback` defines the output returned if no matches for the regex are located.
         """
+        super().__init__(**kwargs)
         self.regex_pattern = regex_pattern
         self.regex = re.compile(regex_pattern)
         self.group_select = group_select
 ...

@@ -66,11 +68,13 @@ class POSFilter(Filter):
         regex_pattern: str = r"\['(.*?)'\]",
         group_select=0,
         fallback=None,
+        **kwargs,
     ) -> None:
         """
         pass a string `regex` to run `re.compile(r"regex")` on.
         `fallback` defines the output returned if no matches for the regex are located.
         """
+        super().__init__(**kwargs)
         if fallback is None:
             fallback = ["invalid"]
         self.regex_pattern = regex_pattern
 ...
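For reference, what the default `regex_pattern` above extracts; a standalone `re` example (the filter's `apply` method itself is not shown in this hunk):

    import re

    # Default pattern from RegexFilter: capture the number after "#### ".
    pattern = re.compile(r"#### (\-?[0-9\.\,]+)")

    match = pattern.search("The answer is 42.\n#### 42")
    answer = match.group(1) if match else "[invalid]"  # "[invalid]" is the fallback
    assert answer == "42"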

lm_eval/filters/selection.py  (view file @ 312374bc)

 from collections import Counter
+from typing import Iterable, TypeVar

 from lm_eval.api.filter import Filter
 from lm_eval.api.registry import register_filter

+T = TypeVar("T")
+

 # TODO: implement "arg_max" filter. either it should take in an arbitrary "scoring"/reward function
 # that takes an input and returns a scalar and then should select the max reward,
 # or should implement different filters for different ways of handling a reward model's inference.
 ...

@@ -11,26 +13,20 @@ from lm_eval.api.registry import register_filter
 @register_filter("take_first")
 class TakeFirstFilter(Filter):
-    def __init__(self) -> None:
-        """
-        Can define custom behavior here, if an individual instantiation of a Filter class should have state.
-        """
-
-    def apply(self, resps, docs):
+    def apply(self, resps: Iterable[list[T]], docs: list[dict]) -> Iterable[list[T]]:
         """
         Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
         """
-        return map(lambda r: r, resps)
+        return map(lambda r: [r[0]], resps)


 @register_filter("take_first_k")
 class TakeKFilter(Filter):
     def __init__(self, **kwargs) -> None:
         self.k = kwargs.pop("k")

         super().__init__(**kwargs)

-    def apply(self, resps, docs):
+    def apply(self, resps: Iterable[list[T]], docs: list[dict]) -> Iterable[list[T]]:
         # need resp to be subscriptable to check below
         resps = list(resps)
         # check we have at least k responses per doc, else we can't take the first k
 ...

@@ -42,12 +38,7 @@ class TakeKFilter(Filter):
 @register_filter("majority_vote")
 class MajorityVoteFilter(Filter):
-    def __init__(self) -> None:
-        """
-        Can define custom behavior here, if an individual instantiation of a Filter class should have state.
-        """
-
-    def apply(self, resps, docs):
+    def apply(self, resps: Iterable[list[T]], docs: list[dict]) -> Iterable[list[T]]:
         """
         Each entry of `resps` is a list of model responses.
         We select the response that occurs most frequently in each entry of `resps`.
 ...
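The corrected `take_first` behaviour keeps each document's surviving response wrapped in a list, so downstream code still sees one list per document. A toy before/after illustration (values are illustrative only):

    resps = [["yes", "no", "yes"], ["4", "5"]]

    # Old behaviour: identity map, every response kept.
    old = list(map(lambda r: r, resps))       # [['yes', 'no', 'yes'], ['4', '5']]

    # New behaviour: only the first response, still wrapped in a list.
    new = list(map(lambda r: [r[0]], resps))  # [['yes'], ['4']]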

lm_eval/filters/transformation.py  (view file @ 312374bc)

 ...
@@ -6,9 +6,6 @@ from lm_eval.api.registry import register_filter
 @register_filter("lowercase")
 class LowercaseFilter(Filter):
-    def __init__(self) -> None:
-        pass
-
     def apply(self, resps, docs):
         def filter_set(inst):
             return [resp.lower() for resp in inst]
 ...

@@ -18,9 +15,6 @@ class LowercaseFilter(Filter):
 @register_filter("uppercase")
 class UppercaseFilter(Filter):
-    def __init__(self) -> None:
-        pass
-
     def apply(self, resps, docs):
         def filter_set(inst):
             return [resp.upper() for resp in inst]
 ...

@@ -30,7 +24,7 @@ class UppercaseFilter(Filter):
 @register_filter("map")
 class MapFilter(Filter):
-    def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
+    def __init__(self, mapping_dict: dict = None, default_value=None, **kwargs) -> None:
         """
         Initializes the MapFilter with a given mapping dictionary and default value.
 ...

@@ -43,6 +37,7 @@ class MapFilter(Filter):
         Example:
         mapper = MapFilter({'A': 1, 'B': 2}, default_value=0)
         """
+        super().__init__(**kwargs)
         if mapping_dict is None:
             mapping_dict = {}
         assert isinstance(mapping_dict, dict), (
 ...

@@ -60,9 +55,6 @@
 @register_filter("format_span")
 class SPANFilter(Filter):
-    def __init__(self) -> None:
-        pass
-
     def apply(self, resps, docs):
         def format_ner_text(text):
             label_dict = {
 ...
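The `MapFilter` docstring's own example still applies with the new `**kwargs` pass-through; a short usage sketch (the expected output assumes the mapping behaviour described in that docstring):

    from lm_eval.filters.transformation import MapFilter

    # From the docstring above: map labels to integers, unknown labels fall back to 0.
    mapper = MapFilter({"A": 1, "B": 2}, default_value=0)

    filtered = mapper.apply([["A", "C"], ["B"]], docs=[{}, {}])
    # expected: [[1, 0], [2]]  ("C" is unknown, so it maps to default_value)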

lm_eval/utils.py  (view file @ 312374bc)

 ...
@@ -562,9 +562,10 @@ def create_sample_log(
     doc: dict,
     doc_id: int,
     target: Any,
-    requests: list[Instance],
-    metric_names: [dict],
+    requests: tuple[Instance],
+    metric_names: tuple[str, ...],
     filter_key: str,
+    metrics: tuple[dict, ...],
 ) -> dict:
     return {
         "doc_id": doc_id,
 ...

@@ -574,7 +575,8 @@ def create_sample_log(
         "resps": [req.resps for req in requests],
         "filtered_resps": [req.filtered_resps[filter_key] for req in requests],
         "filter": filter_key,
-        "metrics": metric_names,
+        "metric_names": metric_names,
+        "metrics": metrics,
         "doc_hash": hash_string(
             json.dumps(
                 requests[0].doc,
 ...
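With the new signature, `create_sample_log` records metric names and metric values under separate keys. A hedged sketch of the resulting entries, limited to the keys visible in this hunk (values are illustrative):

    sample_log = {
        "doc_id": 0,
        "filter": "take_first",
        "metric_names": ("exact_match",),    # tuple[str, ...]
        "metrics": ({"exact_match": 1.0},),  # tuple[dict, ...]
    }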