gaoqiong / lm-evaluation-harness

Commit db5dff9c authored Jul 03, 2025 by Baber
Commit message: type hints
Parent: 023bfe0d

Showing 3 changed files with 105 additions and 96 deletions (+105 −96)
lm_eval/api/model.py     +30 −18
lm_eval/api/registry.py   +6 −2
lm_eval/api/task.py      +69 −76
lm_eval/api/model.py
@@ -3,7 +3,7 @@ import hashlib
 import json
 import logging
 import os
-from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import TYPE_CHECKING, Optional, Type, TypeVar, Union

 import transformers
 from sqlitedict import SqliteDict
@@ -12,6 +12,10 @@ from tqdm import tqdm
 from lm_eval import utils


+if TYPE_CHECKING:
+    from lm_eval.api.instance import Instance
+
+
 eval_logger = logging.getLogger(__name__)

 T = TypeVar("T", bound="LM")
@@ -30,7 +34,7 @@ class LM(abc.ABC):
         self.cache_hook = CacheHook(None)

     @abc.abstractmethod
-    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
+    def loglikelihood(self, requests) -> list[tuple[float, bool]]:
         """Compute log-likelihood of generating a continuation from a context.
         Downstream tasks should attempt to use loglikelihood instead of other
         LM calls whenever possible.
@@ -55,7 +59,7 @@ class LM(abc.ABC):
         pass

     @abc.abstractmethod
-    def loglikelihood_rolling(self, requests) -> List[float]:
+    def loglikelihood_rolling(self, requests) -> list[float]:
         """Compute full log-likelihood of a string, with no truncation, for perplexity computation
         - We will use the full max context length of the model.
         - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
@@ -97,7 +101,7 @@ class LM(abc.ABC):
     # TODO: Add an optional max length
     @abc.abstractmethod
-    def generate_until(self, requests) -> List[str]:
+    def generate_until(self, requests) -> list[str]:
         """Generate greedily until a stopping sequence

         :param requests: list[Instance]
@@ -114,7 +118,7 @@ class LM(abc.ABC):
         pass

     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt=True
+        self, chat_history: list[dict[str, str]], add_generation_prompt=True
     ) -> str:
         """
         Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.
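For orientation, here is a minimal sketch of what a subclass conforming to the retyped abstract interface above could look like. The ConstantLM class and its fixed scores are purely illustrative and not part of this commit; it only mirrors the signatures shown in the hunks.

from lm_eval.api.model import LM


class ConstantLM(LM):
    """Illustrative stub matching the new builtin-generic signatures."""

    def loglikelihood(self, requests) -> list[tuple[float, bool]]:
        # one (log-probability, is-greedy) pair per request
        return [(-1.0, True) for _ in requests]

    def loglikelihood_rolling(self, requests) -> list[float]:
        # one rolling log-probability per request
        return [-10.0 for _ in requests]

    def generate_until(self, requests) -> list[str]:
        # one generated string per request
        return ["" for _ in requests]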
@@ -173,14 +177,14 @@ class LM(abc.ABC):
         return cls(**arg_dict, **additional_config)

     @property
-    def rank(self):
+    def rank(self) -> int:
         # used in the case of parallelism. Hardcoded to
         # ensure no errors arise using API models which do
         # not support multi-device parallelism nor expect it.
         return self._rank

     @property
-    def world_size(self):
+    def world_size(self) -> int:
         # used in the case of parallelism. Hardcoded to
         # ensure no errors arise using API models which do
         # not support multi-device parallelism nor expect it.
@@ -230,7 +234,7 @@ class CacheHook:
 class CachingLM:
-    def __init__(self, lm, cache_db) -> None:
+    def __init__(self, lm: "LM", cache_db: str) -> None:
         """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

         :param lm: LM
@@ -253,7 +257,7 @@ class CachingLM:
             eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
             return lm_attr

-        def fn(requests):
+        def fn(requests: list[Instance]) -> list[Instance]:
             res = []
             remaining_reqs = []
             warned = False
@@ -322,28 +326,35 @@ class TemplateLM(LM):
     @property
     @abc.abstractmethod
-    def eot_token_id(self):
+    def eot_token_id(self) -> int:
         pass

     @property
-    def prefix_token_id(self):
+    def prefix_token_id(self) -> int:
         # it is used as prefix for loglikelihood
         return self.eot_token_id

     @abc.abstractmethod
-    def tok_encode(self, string: str, **kwargs) -> List[int]:
+    def tok_encode(self, string: str, **kwargs) -> list[int]:
         """
         Tokenize a string using the model's tokenizer and return a list of token IDs.
         """
         pass

     @abc.abstractmethod
-    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
+    def _loglikelihood_tokens(
+        self, requests: list[Instance], **kwargs
+    ) -> list[tuple[float, bool]]:
         pass

     def _encode_pair(
         self, context: str, continuation: str
-    ) -> Tuple[List[int], List[int]]:
+    ) -> tuple[list[int], list[int]]:
         """Encodes a pair of context and continuation strings into token IDs.
         Ensures that encode(context + continuation) == encode(context) + encode(continuation)
         """
         n_spaces = len(context) - len(context.rstrip())
         if n_spaces > 0:
             continuation = context[-n_spaces:] + continuation
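The _encode_pair contract quoted above (encode(context + continuation) == encode(context) + encode(continuation)) relies on the whitespace handling visible at the end of the hunk. A standalone sketch of that string-level step, assuming the context-trimming line sits in the elided part of the function:

def split_trailing_spaces(context: str, continuation: str) -> tuple[str, str]:
    # Mirrors the handling in TemplateLM._encode_pair: trailing spaces are moved
    # from the context onto the continuation before tokenization, so that
    # encode(context) + encode(continuation) == encode(context + continuation).
    n_spaces = len(context) - len(context.rstrip())
    if n_spaces > 0:
        continuation = context[-n_spaces:] + continuation
        context = context[:-n_spaces]  # assumed from the elided portion of the hunk
    return context, continuation


# e.g. ("Question: 2+2= ", "4") becomes ("Question: 2+2=", " 4")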
@@ -364,8 +375,8 @@ class TemplateLM(LM):
         return context_enc, continuation_enc

     def loglikelihood(
-        self, requests, disable_tqdm: bool = False
-    ) -> List[Tuple[float, bool]]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[tuple[float, bool]]:
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
             if context == "":
@@ -384,15 +395,16 @@ class TemplateLM(LM):
     @abc.abstractmethod
     def loglikelihood_rolling(
         self, requests, disable_tqdm: bool = False
-    ) -> List[float]:
+    ) -> list[float]:
         pass

     @abc.abstractmethod
-    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
+    def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
         pass

     def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
         """
         Assumes tokenizer has a chat_template attribute (self.tokenizer.chat_template: dict | str)
         Set and get the appropriate chat template for the model.
         This method sets the tokenizer's chat_template and returns the template string for reproducibility.
lm_eval/api/registry.py
@@ -8,6 +8,10 @@ if TYPE_CHECKING:
 eval_logger = logging.getLogger(__name__)

 MODEL_REGISTRY = {}
+DEFAULTS = {
+    "model": {"max_length": 2048},
+    "tasks": {"generate_until": {"max_length": 2048}},
+}


 def register_model(*names):
@@ -167,7 +171,7 @@ def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callabl
         eval_logger.warning(f"{name} metric is not assigned a default aggregation!")


-def is_higher_better(metric_name) -> Optional[bool]:
+def is_higher_better(metric_name: str) -> Optional[bool]:
     try:
         return HIGHER_IS_BETTER_REGISTRY[metric_name]
     except KeyError:
@@ -176,7 +180,7 @@ def is_higher_better(metric_name) -> Optional[bool]:
         )


-def register_filter(name):
+def register_filter(name: str):
     def decorate(cls):
         if name in FILTER_REGISTRY:
             eval_logger.info(
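Both decorators touched in this file are used as class registrars elsewhere in the harness. A hedged usage sketch follows; the model alias, filter name, and class names are hypothetical, and real implementations would subclass the appropriate base classes and fill in their abstract methods.

from lm_eval.api.model import LM
from lm_eval.api.registry import register_filter, register_model


@register_model("my-api-model")  # register_model(*names) accepts one or more aliases
class MyAPIModel(LM):
    ...


@register_filter("strip_whitespace")  # register_filter(name: str) keys FILTER_REGISTRY by name
class StripWhitespaceFilter:
    ...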
lm_eval/api/task.py
@@ -9,6 +9,7 @@ from dataclasses import asdict, dataclass, field
 from functools import cached_property
 from inspect import getsource
 from typing import (
+    TYPE_CHECKING,
     Any,
     Dict,
     Iterable,
@@ -50,6 +51,10 @@ ALL_OUTPUT_TYPES = [
     "generate_until",
 ]

+if TYPE_CHECKING:
+    from lm_eval.api.filter import FilterEnsemble
+
+
 eval_logger = logging.getLogger(__name__)
@@ -81,7 +86,7 @@ class MetricConfig:
             return is_higher_better(self.name)
         return self.higher_is_better

-    def calculate_metric(self, *args, **kwargs) -> Any:
+    def compute_metric(self, *args, **kwargs) -> Any:
         """Calculates the metric using the provided function and arguments."""
         if self.fn is None:
             raise ValueError(f"Metric function for {self.name} is not defined.")
@@ -99,7 +104,7 @@ class RepeatConfig:
     """Encapsulates information about a single repeat."""

     repeats: int = 1
-    metric_fn: Optional[Callable] = None
+    metric_fn: Optional[Callable] = "pass@N"
     kwargs: Optional[dict] = None
@@ -246,15 +251,15 @@ class TaskConfig(dict):
     output_type: OutputType = "generate_until"
     generation_kwargs: Optional[dict] = None
     repeats: int = 1
-    filter_list: Optional[Union[str, list]] = None
+    filter_list: Optional[list[dict]] = None
     should_decontaminate: bool = False
     doc_to_decontamination_query: Optional[str] = None
     gen_prefix: Optional[str] = None
     metadata: Optional[dict] = (
         None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
     )

-    _metric_list = None
-    _filter_list = None
+    _metric_list: list[MetricConfig] = None
+    _filter_list: list[FilterConfig] = None

     def __post_init__(self) -> None:
         if self.generation_kwargs is not None:
@@ -289,16 +294,13 @@ class TaskConfig(dict):
                     f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
                 )

-        if self.metric_list is not None:
-            for metric_config in self.metric_list:
-                if "metric" not in metric_config:
-                    raise ValueError(
-                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
-                    )
+        if self.metric_list and not all("metric" in cfg for cfg in self.metric_list):
+            raise ValueError(
+                "each entry in metric_list must include a 'metric' key"
+            )

     def get_metrics(self) -> list["MetricConfig"]:
         metrics = []
         if self.metric_list is None:
             # ---------- 1. If no metrics defined, use defaults for output type ----------
             _metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
             eval_logger.info(
                 f"No metrics defined in config, using default metrics for {self.output_type}={_metric_list}"
@@ -313,11 +315,8 @@ class TaskConfig(dict):
                 for metric_name in _metric_list
             )
         else:
+            # ---------- 2. How will the samples be evaluated ----------
             for metric_config in self.metric_list:
-                if "metric" not in metric_config:
-                    raise ValueError(
-                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
-                    )
                 metric_name = metric_config["metric"]
                 _metric_fn_kwargs = {
                     key: metric_config[key]
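The tightened __post_init__ check means every metric_list entry must carry a "metric" key; any remaining keys flow through get_metrics as metric arguments (see _metric_fn_kwargs above). A hypothetical config that passes the new validation, with illustrative metric names and options:

metric_list = [
    {"metric": "exact_match", "aggregation": "mean", "higher_is_better": True},
    {"metric": "acc"},
]
# An entry without a "metric" key, e.g. {"aggregation": "mean"}, now raises:
#   ValueError: each entry in metric_list must include a 'metric' key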
@@ -379,34 +378,30 @@ class TaskConfig(dict):
             )

         return metrics

-    def get_filters(self):
-        if self.filter_list is not None:
-            _filter_list = []
-            if isinstance(self.filter_list, dict):
-                for filter_config in self.filter_list:
-                    _filter_list.append(
-                        build_filter_ensemble(
-                            filter_name=filter_config["name"],
-                            components=[
-                                [
-                                    {
-                                        key: function[key]
-                                        for key in function
-                                        if key != "function"
-                                    }
-                                ]
-                                for function in filter_config["filter"]
-                            ],
-                        )
-                    )
-        else:
-            # TODO: handle repeats in a more general way rather than just discarding
-            eval_logger.debug(
-                "No custom filters defined. Using default 'take_first' filter for handling repeats."
-            )
-            _filter_list = [build_filter_ensemble("none", [["take_first", None]])]
-        return _filter_list
+    def get_filters(self) -> list["FilterEnsemble"]:
+        if not self.filter_list:
+            eval_logger.debug(
+                "No custom filters defined; falling back to 'take_first' for handling repeats."
+            )
+            return [build_filter_ensemble("none", [["take_first", None]])]
+
+        def _strip_fn(d: dict) -> dict:
+            return {k: v for k, v in d.items() if k != "function"}
+
+        configs = (
+            self.filter_list.values()
+            if isinstance(self.filter_list, dict)
+            else self.filter_list
+        )
+        return [
+            build_filter_ensemble(
+                filter_name=cfg["name"],
+                components=[[_strip_fn(f) for f in cfg["filter"]]],
+            )
+            for cfg in configs
+        ]

     def __getitem__(self, item):
         return getattr(self, item)
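The rewritten get_filters accepts either a list of filter configs or a dict of them (via .values()), and falls back to a take_first ensemble when filter_list is empty or None. A hypothetical filter_list in the shape the new code consumes; the filter name, functions, and regex below are illustrative, not taken from this diff:

filter_list = [
    {
        "name": "strict-match",
        "filter": [
            {"function": "regex", "regex_pattern": r"#### (\-?[0-9\.\,]+)"},
            {"function": "take_first"},
        ],
    },
]
# With filter_list=None or [], get_filters() returns
# [build_filter_ensemble("none", [["take_first", None]])].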
@@ -415,31 +410,27 @@ class TaskConfig(dict):
         return setattr(self, item, value)

     def to_dict(self, keep_callable: bool = False) -> dict:
-        """dumps the current config as a dictionary object, as a printable format.
-        null fields will not be printed.
-        Used for dumping results alongside full task configuration
+        """Return a printable dict with Nones stripped and callables serialised.

         :return: dict
             A printable dictionary version of the TaskConfig object.

-        # TODO: should any default value in the TaskConfig not be printed?
         """
-        cfg_dict = asdict(self)
-        # remove values that are `None`
-        for k, v in list(cfg_dict.items()):
-            if v is None:
-                cfg_dict.pop(k)
-            elif k == "metric_list":
-                for metric_dict in v:
-                    for metric_key, metric_value in metric_dict.items():
-                        if callable(metric_value):
-                            metric_dict[metric_key] = self.serialize_function(
-                                metric_value, keep_callable=keep_callable
-                            )
-                cfg_dict[k] = v
-            elif callable(v):
-                cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
-        return cfg_dict
+
+        def _maybe_serialize(val):
+            return (
+                self.serialize_function(val, keep_callable=keep_callable)
+                if callable(val)
+                else val
+            )
+
+        cfg = asdict(self)
+        return {
+            k: [{mk: _maybe_serialize(mv) for mk, mv in md.items()} for md in v]
+            if k == "metric_list"
+            else _maybe_serialize(v)
+            for k, v in cfg.items()
+            if v is not None
+        }

     def serialize_function(
         self, value: Union[Callable, str], keep_callable=False
@@ -627,7 +618,7 @@ class Task(abc.ABC):
         return doc

     @property
-    def instances(self) -> List[Instance]:
+    def instances(self) -> list[Instance]:
         """After calling `task.build_all_requests()`, tasks
         maintain a list of the dataset instances which will be evaluated.
         """
@@ -639,27 +630,27 @@ class Task(abc.ABC):
         return rnd.sample(self._training_docs, k)

-    def doc_to_decontamination_query(self, doc):
+    def doc_to_decontamination_query(self, doc: dict):
         raise NotImplementedError(
             "Override doc_to_decontamination_query with document specific decontamination query."
         )

     @abc.abstractmethod
-    def doc_to_text(self, doc) -> str:
+    def doc_to_text(self, doc: dict) -> str:
         pass

     @abc.abstractmethod
-    def doc_to_target(self, doc) -> Union[str, int]:
+    def doc_to_target(self, doc: dict) -> Union[str, int]:
         pass

     # not an abstractmethod because not every language-only task has to implement this
-    def doc_to_image(self, doc):
+    def doc_to_image(self, doc: dict):
         raise NotImplementedError

-    def doc_to_audio(self, doc):
+    def doc_to_audio(self, doc: dict):
         raise NotImplementedError

-    def doc_to_prefix(self, doc) -> str:
+    def doc_to_prefix(self, doc: dict) -> str:
         return ""

     def build_all_requests(
@@ -776,7 +767,7 @@ class Task(abc.ABC):
             save_to_cache(file_name=cache_key, obj=instances)

     @abc.abstractmethod
-    def construct_requests(self, doc, ctx, **kwargs):
+    def construct_requests(self, doc: dict, ctx: Union[list[dict], str], **kwargs):
         """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.
@@ -797,7 +788,7 @@ class Task(abc.ABC):
         pass

     @abc.abstractmethod
-    def process_results(self, doc, results):
+    def process_results(self, doc: dict, results: list):
         """Take a single document and the LM results and evaluates, returning a
         dict where keys are the names of submetrics and values are the values of
         the metric for that one document
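To make the newly annotated process_results(self, doc: dict, results: list) contract concrete, a hypothetical implementation for a single-answer task might look as follows; the "gold" field and "exact_match" submetric are illustrative, not taken from this diff.

def process_results(self, doc: dict, results: list) -> dict:
    # `results` holds the LM outputs for this document, in request order;
    # the return value maps each submetric name to its value for this doc.
    prediction = results[0].strip()
    return {"exact_match": float(prediction == doc["gold"])}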
@@ -1446,7 +1437,7 @@ class ConfigurableTask(Task):
         """
         return doc

-    def doc_to_text(self, doc, doc_to_text=None):
+    def doc_to_text(self, doc: dict, doc_to_text: Optional[int, str, Callable] = None):
         if self.prompt is not None:
             doc_to_text = self.prompt
         elif doc_to_text is not None:
@@ -1482,7 +1473,7 @@ class ConfigurableTask(Task):
             print(type(doc_to_text))
             raise TypeError

-    def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
+    def doc_to_target(self, doc: dict, doc_to_target=None) -> Union[int, str, list]:
         if self.prompt is not None:
             doc_to_target = self.prompt
         elif doc_to_target is not None:
@@ -1528,7 +1519,9 @@ class ConfigurableTask(Task):
         else:
             raise TypeError

-    def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
+    def doc_to_choice(
+        self, doc: dict, doc_to_choice: Union[str, list, dict] = None
+    ) -> List[str]:
         if self.prompt is not None:
             doc_to_choice = self.prompt
         elif doc_to_choice is not None:
@@ -1554,7 +1547,7 @@ class ConfigurableTask(Task):
         else:
             raise TypeError

-    def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list, None]:
+    def doc_to_image(self, doc: dict, doc_to_image=None) -> Union[int, str, list, None]:
         if doc_to_image is not None:
             doc_to_image = doc_to_image
         elif self.config.doc_to_image is not None:
@@ -1600,7 +1593,7 @@ class ConfigurableTask(Task):
         else:
             return None

-    def doc_to_prefix(self, doc) -> Optional[str]:
+    def doc_to_prefix(self, doc: dict) -> Optional[str]:
         if (gen_prefix := self.config.gen_prefix) is not None:
             if gen_prefix in self.features:
                 return doc[gen_prefix]
@@ -1709,7 +1702,7 @@ class ConfigurableTask(Task):
             **kwargs,
         )

-    def process_results(self, doc, results):
+    def process_results(self, doc: dict, results: list) -> dict:
         if callable(self.config.process_results):
             return self.config.process_results(doc, results)