Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
04e74420
Commit
04e74420
authored
Jul 03, 2025
by
Baber
Browse files
cleanup
parent
b0173d57
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
93 additions
and
94 deletions
+93
-94
lm_eval/api/model.py
lm_eval/api/model.py
+11
-5
lm_eval/api/registry.py
lm_eval/api/registry.py
+6
-2
lm_eval/api/task.py
lm_eval/api/task.py
+69
-76
lm_eval/filters/decontamination.py
lm_eval/filters/decontamination.py
+2
-1
lm_eval/filters/extraction.py
lm_eval/filters/extraction.py
+4
-0
lm_eval/filters/selection.py
lm_eval/filters/selection.py
+0
-1
lm_eval/filters/transformation.py
lm_eval/filters/transformation.py
+1
-9
No files found.
lm_eval/api/model.py
View file @
04e74420
...
...
@@ -176,14 +176,14 @@ class LM(abc.ABC):
return
cls
(
**
arg_dict
,
**
additional_config
)
@
property
def
rank
(
self
):
def
rank
(
self
)
->
int
:
# used in the case of parallelism. Hardcoded to
# ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it.
return
self
.
_rank
@
property
def
world_size
(
self
):
def
world_size
(
self
)
->
int
:
# used in the case of parallelism. Hardcoded to
# ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it.
...
...
@@ -233,7 +233,7 @@ class CacheHook:
class
CachingLM
:
def
__init__
(
self
,
lm
:
LM
,
cache_db
:
str
)
->
None
:
def
__init__
(
self
,
lm
:
"
LM
"
,
cache_db
:
str
)
->
None
:
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
:param lm: LM
...
...
@@ -327,11 +327,11 @@ class TemplateLM(LM):
@
property
@
abc
.
abstractmethod
def
eot_token_id
(
self
):
def
eot_token_id
(
self
)
->
int
:
pass
@
property
def
prefix_token_id
(
self
):
def
prefix_token_id
(
self
)
->
int
:
# it is used as prefix for loglikelihood
return
self
.
eot_token_id
...
...
@@ -351,6 +351,11 @@ class TemplateLM(LM):
def
_encode_pair
(
self
,
context
:
str
,
continuation
:
str
)
->
tuple
[
list
[
int
],
list
[
int
]]:
"""Encodes a pair of context and continuation strings into token IDs.
Ensures that encode(context + continuation) == encode(context) + encode(continuation)
"""
import
transformers
n_spaces
=
len
(
context
)
-
len
(
context
.
rstrip
())
...
...
@@ -402,6 +407,7 @@ class TemplateLM(LM):
def
chat_template
(
self
,
chat_template
:
Union
[
bool
,
str
]
=
False
)
->
Optional
[
str
]:
"""
Assumes tokenizer has a chat_template attribute (self.tokenizer.chat_template: dict | str)
Set and get the appropriate chat template for the model.
This method sets the tokenizer's chat_template and returns the template string for reproducibility.
...
...
lm_eval/api/registry.py
View file @
04e74420
...
...
@@ -8,6 +8,10 @@ if TYPE_CHECKING:
eval_logger
=
logging
.
getLogger
(
__name__
)
MODEL_REGISTRY
=
{}
DEFAULTS
=
{
"model"
:
{
"max_length"
:
2048
},
"tasks"
:
{
"generate_until"
:
{
"max_length"
:
2048
}},
}
def
register_model
(
*
names
):
...
...
@@ -167,7 +171,7 @@ def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callabl
eval_logger
.
warning
(
f
"
{
name
}
metric is not assigned a default aggregation!"
)
def
is_higher_better
(
metric_name
)
->
Optional
[
bool
]:
def
is_higher_better
(
metric_name
:
str
)
->
Optional
[
bool
]:
try
:
return
HIGHER_IS_BETTER_REGISTRY
[
metric_name
]
except
KeyError
:
...
...
@@ -176,7 +180,7 @@ def is_higher_better(metric_name) -> Optional[bool]:
)
def
register_filter
(
name
):
def
register_filter
(
name
:
str
):
def
decorate
(
cls
):
if
name
in
FILTER_REGISTRY
:
eval_logger
.
info
(
...
...
lm_eval/api/task.py
View file @
04e74420
...
...
@@ -9,6 +9,7 @@ from dataclasses import asdict, dataclass, field
from
functools
import
cached_property
from
inspect
import
getsource
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Iterable
,
...
...
@@ -50,6 +51,10 @@ ALL_OUTPUT_TYPES = [
"generate_until"
,
]
if
TYPE_CHECKING
:
from
lm_eval.api.filter
import
FilterEnsemble
eval_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -81,7 +86,7 @@ class MetricConfig:
return
is_higher_better
(
self
.
name
)
return
self
.
higher_is_better
def
c
alcula
te_metric
(
self
,
*
args
,
**
kwargs
)
->
Any
:
def
c
ompu
te_metric
(
self
,
*
args
,
**
kwargs
)
->
Any
:
"""Calculates the metric using the provided function and arguments."""
if
self
.
fn
is
None
:
raise
ValueError
(
f
"Metric function for
{
self
.
name
}
is not defined."
)
...
...
@@ -99,7 +104,7 @@ class RepeatConfig:
"""Encapsulates information about a single repeat."""
repeats
:
int
=
1
metric_fn
:
Optional
[
Callable
]
=
None
metric_fn
:
Optional
[
str
,
Callable
]
=
"pass@N"
kwargs
:
Optional
[
dict
]
=
None
...
...
@@ -246,15 +251,15 @@ class TaskConfig(dict):
output_type
:
OutputType
=
"generate_until"
generation_kwargs
:
Optional
[
dict
]
=
None
repeats
:
int
=
1
filter_list
:
Optional
[
Union
[
str
,
lis
t
]]
=
None
filter_list
:
Optional
[
list
[
dic
t
]]
=
None
should_decontaminate
:
bool
=
False
doc_to_decontamination_query
:
Optional
[
str
]
=
None
gen_prefix
:
Optional
[
str
]
=
None
metadata
:
Optional
[
dict
]
=
(
None
# by default, not used in the code. allows for users to pass arbitrary info to tasks
)
_metric_list
=
None
_filter_list
=
None
_metric_list
:
list
[
MetricConfig
]
=
None
_filter_list
:
list
[
FilterConfig
]
=
None
def
__post_init__
(
self
)
->
None
:
if
self
.
generation_kwargs
is
not
None
:
...
...
@@ -289,16 +294,13 @@ class TaskConfig(dict):
f
"
{
self
.
task
}
: No `generation_kwargs` specified in task config, defaulting to
{
self
.
generation_kwargs
}
"
)
if
self
.
metric_list
is
not
None
:
for
metric_config
in
self
.
metric_list
:
if
"metric"
not
in
metric_config
:
raise
ValueError
(
"'metric' key not provided for an entry in 'metric_list', must be specified!"
)
if
self
.
metric_list
and
not
all
(
"metric"
in
cfg
for
cfg
in
self
.
metric_list
):
raise
ValueError
(
"each entry in metric_list must include a 'metric' key"
)
def
get_metrics
(
self
)
->
list
[
"MetricConfig"
]:
metrics
=
[]
if
self
.
metric_list
is
None
:
# ---------- 1. If no metrics defined, use defaults for output type ----------
_metric_list
=
DEFAULT_METRIC_REGISTRY
[
self
.
output_type
]
eval_logger
.
info
(
f
"No metrics defined in config, using default metrics for
{
self
.
output_type
}
=
{
_metric_list
}
"
...
...
@@ -313,11 +315,8 @@ class TaskConfig(dict):
for
metric_name
in
_metric_list
)
else
:
# ---------- 2. How will the samples be evaluated ----------
for
metric_config
in
self
.
metric_list
:
if
"metric"
not
in
metric_config
:
raise
ValueError
(
"'metric' key not provided for an entry in 'metric_list', must be specified!"
)
metric_name
=
metric_config
[
"metric"
]
_metric_fn_kwargs
=
{
key
:
metric_config
[
key
]
...
...
@@ -379,34 +378,30 @@ class TaskConfig(dict):
)
return
metrics
def
get_filters
(
self
):
if
self
.
filter_list
is
not
None
:
_filter_list
=
[]
if
isinstance
(
self
.
filter_list
,
dict
):
for
filter_config
in
self
.
filter_list
:
_filter_list
.
append
(
build_filter_ensemble
(
filter_name
=
filter_config
[
"name"
],
components
=
[
[
{
key
:
function
[
key
]
for
key
in
function
if
key
!=
"function"
}
]
for
function
in
filter_config
[
"filter"
]
],
)
)
else
:
# TODO: handle repeats in a more general way rather than just discarding
def
get_filters
(
self
)
->
list
[
"FilterEnsemble"
]:
if
not
self
.
filter_list
:
eval_logger
.
debug
(
"No custom filters defined
. Using default
'take_first'
filter
for handling repeats."
"No custom filters defined
; falling back to
'take_first' for handling repeats."
)
_filter_list
=
[
build_filter_ensemble
(
"none"
,
[[
"take_first"
,
None
]])]
return
[
build_filter_ensemble
(
"none"
,
[[
"take_first"
,
None
]])]
else
:
return
_filter_list
def
_strip_fn
(
d
:
dict
)
->
dict
:
return
{
k
:
v
for
k
,
v
in
d
.
items
()
if
k
!=
"function"
}
configs
=
(
self
.
filter_list
.
values
()
if
isinstance
(
self
.
filter_list
,
dict
)
else
self
.
filter_list
)
return
[
build_filter_ensemble
(
filter_name
=
cfg
[
"name"
],
components
=
[[
_strip_fn
(
f
)
for
f
in
cfg
[
"filter"
]]],
)
for
cfg
in
configs
]
def
__getitem__
(
self
,
item
):
return
getattr
(
self
,
item
)
...
...
@@ -415,31 +410,27 @@ class TaskConfig(dict):
return
setattr
(
self
,
item
,
value
)
def
to_dict
(
self
,
keep_callable
:
bool
=
False
)
->
dict
:
"""dumps the current config as a dictionary object, as a printable format.
null fields will not be printed.
Used for dumping results alongside full task configuration
"""Return a printable dict with Nones stripped and callables serialised.
:return: dict
A printable dictionary version of the TaskConfig object.
# TODO: should any default value in the TaskConfig not be printed?
"""
cfg_dict
=
asdict
(
self
)
# remove values that are `None`
for
k
,
v
in
list
(
cfg_dict
.
items
()):
if
v
is
None
:
cfg_dict
.
pop
(
k
)
elif
k
==
"metric_list"
:
for
metric_dict
in
v
:
for
metric_key
,
metric_value
in
metric_dict
.
items
():
if
callable
(
metric_value
):
metric_dict
[
metric_key
]
=
self
.
serialize_function
(
metric_value
,
keep_callable
=
keep_callable
)
cfg_dict
[
k
]
=
v
elif
callable
(
v
):
cfg_dict
[
k
]
=
self
.
serialize_function
(
v
,
keep_callable
=
keep_callable
)
return
cfg_dict
def
_maybe_serialize
(
val
):
return
(
self
.
serialize_function
(
val
,
keep_callable
=
keep_callable
)
if
callable
(
val
)
else
val
)
cfg
=
asdict
(
self
)
return
{
k
:
[{
mk
:
_maybe_serialize
(
mv
)
for
mk
,
mv
in
md
.
items
()}
for
md
in
v
]
if
k
==
"metric_list"
else
_maybe_serialize
(
v
)
for
k
,
v
in
cfg
.
items
()
if
v
is
not
None
}
def
serialize_function
(
self
,
value
:
Union
[
Callable
,
str
],
keep_callable
=
False
...
...
@@ -627,7 +618,7 @@ class Task(abc.ABC):
return
doc
@
property
def
instances
(
self
)
->
L
ist
[
Instance
]:
def
instances
(
self
)
->
l
ist
[
Instance
]:
"""After calling `task.build_all_requests()`, tasks
maintain a list of the dataset instances which will be evaluated.
"""
...
...
@@ -639,27 +630,27 @@ class Task(abc.ABC):
return
rnd
.
sample
(
self
.
_training_docs
,
k
)
def
doc_to_decontamination_query
(
self
,
doc
):
def
doc_to_decontamination_query
(
self
,
doc
:
dict
):
raise
NotImplementedError
(
"Override doc_to_decontamination_query with document specific decontamination query."
)
@
abc
.
abstractmethod
def
doc_to_text
(
self
,
doc
)
->
str
:
def
doc_to_text
(
self
,
doc
:
dict
)
->
str
:
pass
@
abc
.
abstractmethod
def
doc_to_target
(
self
,
doc
)
->
Union
[
str
,
int
]:
def
doc_to_target
(
self
,
doc
:
dict
)
->
Union
[
str
,
int
]:
pass
# not an abstractmethod because not every language-only task has to implement this
def
doc_to_image
(
self
,
doc
):
def
doc_to_image
(
self
,
doc
:
dict
):
raise
NotImplementedError
def
doc_to_audio
(
self
,
doc
):
def
doc_to_audio
(
self
,
doc
:
dict
):
raise
NotImplementedError
def
doc_to_prefix
(
self
,
doc
)
->
str
:
def
doc_to_prefix
(
self
,
doc
:
dict
)
->
str
:
return
""
def
build_all_requests
(
...
...
@@ -776,7 +767,7 @@ class Task(abc.ABC):
save_to_cache
(
file_name
=
cache_key
,
obj
=
instances
)
@
abc
.
abstractmethod
def
construct_requests
(
self
,
doc
,
ctx
,
**
kwargs
):
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
Union
[
list
[
dict
],
str
]
,
**
kwargs
):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
...
...
@@ -797,7 +788,7 @@ class Task(abc.ABC):
pass
@
abc
.
abstractmethod
def
process_results
(
self
,
doc
,
results
):
def
process_results
(
self
,
doc
:
dict
,
results
:
list
):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
...
...
@@ -1450,7 +1441,7 @@ class ConfigurableTask(Task):
"""
return
doc
def
doc_to_text
(
self
,
doc
,
doc_to_text
=
None
):
def
doc_to_text
(
self
,
doc
:
dict
,
doc_to_text
:
Optional
[
int
,
str
,
Callable
]
=
None
):
if
self
.
prompt
is
not
None
:
doc_to_text
=
self
.
prompt
elif
doc_to_text
is
not
None
:
...
...
@@ -1486,7 +1477,7 @@ class ConfigurableTask(Task):
print
(
type
(
doc_to_text
))
raise
TypeError
def
doc_to_target
(
self
,
doc
:
Mapping
,
doc_to_target
=
None
)
->
Union
[
int
,
str
,
list
]:
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
=
None
)
->
Union
[
int
,
str
,
list
]:
if
self
.
prompt
is
not
None
:
doc_to_target
=
self
.
prompt
elif
doc_to_target
is
not
None
:
...
...
@@ -1532,7 +1523,9 @@ class ConfigurableTask(Task):
else
:
raise
TypeError
def
doc_to_choice
(
self
,
doc
:
Any
,
doc_to_choice
=
None
)
->
List
[
str
]:
def
doc_to_choice
(
self
,
doc
:
dict
,
doc_to_choice
:
Union
[
str
,
list
,
dict
]
=
None
)
->
List
[
str
]:
if
self
.
prompt
is
not
None
:
doc_to_choice
=
self
.
prompt
elif
doc_to_choice
is
not
None
:
...
...
@@ -1558,7 +1551,7 @@ class ConfigurableTask(Task):
else
:
raise
TypeError
def
doc_to_image
(
self
,
doc
:
Any
,
doc_to_image
=
None
)
->
Union
[
int
,
str
,
list
,
None
]:
def
doc_to_image
(
self
,
doc
:
dict
,
doc_to_image
=
None
)
->
Union
[
int
,
str
,
list
,
None
]:
if
doc_to_image
is
not
None
:
doc_to_image
=
doc_to_image
elif
self
.
config
.
doc_to_image
is
not
None
:
...
...
@@ -1604,7 +1597,7 @@ class ConfigurableTask(Task):
else
:
return
None
def
doc_to_prefix
(
self
,
doc
)
->
Optional
[
str
]:
def
doc_to_prefix
(
self
,
doc
:
dict
)
->
Optional
[
str
]:
if
(
gen_prefix
:
=
self
.
config
.
gen_prefix
)
is
not
None
:
if
gen_prefix
in
self
.
features
:
return
doc
[
gen_prefix
]
...
...
@@ -1713,7 +1706,7 @@ class ConfigurableTask(Task):
**
kwargs
,
)
def
process_results
(
self
,
doc
,
results
)
:
def
process_results
(
self
,
doc
:
dict
,
results
:
list
)
->
dict
:
if
callable
(
self
.
config
.
process_results
):
return
self
.
config
.
process_results
(
doc
,
results
)
...
...
lm_eval/filters/decontamination.py
View file @
04e74420
...
...
@@ -10,12 +10,13 @@ class DecontaminationFilter(Filter):
name
=
"track_decontamination"
def
__init__
(
self
,
path
)
->
None
:
def
__init__
(
self
,
path
,
**
kwargs
)
->
None
:
"""
TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
should further cache result on a given (task_name, doc_id)
"""
super
().
__init__
(
**
kwargs
)
self
.
_decontam_results
=
None
def
apply
(
self
,
resps
,
docs
)
->
None
:
...
...
lm_eval/filters/extraction.py
View file @
04e74420
...
...
@@ -20,11 +20,13 @@ class RegexFilter(Filter):
regex_pattern
:
str
=
r
"#### (\-?[0-9\.\,]+)"
,
group_select
:
int
=
0
,
fallback
:
str
=
"[invalid]"
,
**
kwargs
,
)
->
None
:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
super
().
__init__
(
**
kwargs
)
self
.
regex_pattern
=
regex_pattern
self
.
regex
=
re
.
compile
(
regex_pattern
)
self
.
group_select
=
group_select
...
...
@@ -66,11 +68,13 @@ class POSFilter(Filter):
regex_pattern
:
str
=
r
"\['(.*?)'\]"
,
group_select
=
0
,
fallback
=
None
,
**
kwargs
,
)
->
None
:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
super
().
__init__
(
**
kwargs
)
if
fallback
is
None
:
fallback
=
[
"invalid"
]
self
.
regex_pattern
=
regex_pattern
...
...
lm_eval/filters/selection.py
View file @
04e74420
...
...
@@ -27,7 +27,6 @@ class TakeFirstFilter(Filter):
class
TakeKFilter
(
Filter
):
def
__init__
(
self
,
**
kwargs
)
->
None
:
self
.
k
=
kwargs
.
pop
(
"k"
)
super
().
__init__
(
**
kwargs
)
def
apply
(
self
,
resps
,
docs
):
...
...
lm_eval/filters/transformation.py
View file @
04e74420
...
...
@@ -6,9 +6,6 @@ from lm_eval.api.registry import register_filter
@
register_filter
(
"lowercase"
)
class
LowercaseFilter
(
Filter
):
def
__init__
(
self
)
->
None
:
pass
def
apply
(
self
,
resps
,
docs
):
def
filter_set
(
inst
):
return
[
resp
.
lower
()
for
resp
in
inst
]
...
...
@@ -18,9 +15,6 @@ class LowercaseFilter(Filter):
@
register_filter
(
"uppercase"
)
class
UppercaseFilter
(
Filter
):
def
__init__
(
self
)
->
None
:
pass
def
apply
(
self
,
resps
,
docs
):
def
filter_set
(
inst
):
return
[
resp
.
upper
()
for
resp
in
inst
]
...
...
@@ -31,6 +25,7 @@ class UppercaseFilter(Filter):
@
register_filter
(
"map"
)
class
MapFilter
(
Filter
):
def
__init__
(
self
,
mapping_dict
:
dict
=
None
,
default_value
=
None
)
->
None
:
super
().
__init__
()
"""
Initializes the MapFilter with a given mapping dictionary and default value.
...
...
@@ -60,9 +55,6 @@ class MapFilter(Filter):
@
register_filter
(
"format_span"
)
class
SPANFilter
(
Filter
):
def
__init__
(
self
)
->
None
:
pass
def
apply
(
self
,
resps
,
docs
):
def
format_ner_text
(
text
):
label_dict
=
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment