Repository: lm-evaluation-harness
Commit 1768fd3b: "ruff rules; types"
Authored Jul 19, 2025 by Baber
Parent commit: f650197a
Showing 7 changed files with 63 additions and 44 deletions (+63 −44)
.pre-commit-config.yaml    +3   −4
lm_eval/api/group.py       +3   −3
lm_eval/api/metrics.py     +31  −25
lm_eval/api/task.py        +13  −4
lm_eval/config/utils.py    +2   −2
lm_eval/evaluator.py       +5   −3
pyproject.toml             +6   −3
.pre-commit-config.yaml

@@ -32,10 +32,9 @@ repos:
     rev: v0.12.2
     hooks:
       # Run the linter.
-      - id: ruff
-        args:
-          - --fix
-      # Run the formatter.
+      - id: ruff-check
+        args: [ --fix ]
+      # Run the formatter.
       - id: ruff-format
   - repo: https://github.com/codespell-project/codespell
     rev: v2.4.1
lm_eval/api/group.py

 from dataclasses import asdict, dataclass
 from inspect import getsource
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, Optional, Union


 @dataclass
 class AggMetricConfig(dict):
     metric: Optional[str] = None
     aggregation: Optional[str] = "mean"
-    weight_by_size: Optional[str] = False
+    weight_by_size: bool = False
     # list of filter names which should be incorporated into the aggregated metric.
     filter_list: Optional[Union[str, list]] = "none"

@@ -27,7 +27,7 @@ class GroupConfig(dict):
     group_alias: Optional[str] = None
     task: Optional[Union[str, list]] = None
     aggregate_metric_list: Optional[
-        Union[List[AggMetricConfig], AggMetricConfig, dict]
+        Union[list[AggMetricConfig], AggMetricConfig, dict]
     ] = None
     version: Optional[str] = None
     metadata: Optional[dict] = (
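The `weight_by_size` change above is a pure typing fix: the old annotation `Optional[str]` never matched the default value `False`, so a type checker (or the stricter ruff rule set introduced in this commit) flags it. A minimal standalone sketch of the corrected field; `AggMetricConfigSketch` is an illustrative name, not harness code:

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class AggMetricConfigSketch:
        metric: Optional[str] = None
        aggregation: Optional[str] = "mean"
        # Old annotation: Optional[str] = False  (the default is neither str nor None).
        # New annotation: a plain bool with a bool default.
        weight_by_size: bool = False


    cfg = AggMetricConfigSketch(metric="acc", weight_by_size=True)
    assert cfg.weight_by_size is True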
lm_eval/api/metrics.py

+from __future__ import annotations
+
 import logging
 import math
 import os
 import random
 import re
 import string
-from collections.abc import Iterable
-from typing import Callable, List, Optional, Sequence, TypeVar
+from collections.abc import Callable, Iterable, Sequence
+from typing import Generic, TypeVar

 import numpy as np
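The new import block leans on two things: `from __future__ import annotations` defers annotation evaluation, so PEP 585 (`list[str]`) and PEP 604 (`str | None`) spellings are safe on the `py39` target set in pyproject.toml, and the ABCs (`Callable`, `Iterable`, `Sequence`) now come from `collections.abc` instead of `typing`, which ruff's UP (pyupgrade) rules enforce. A small self-contained sketch, not taken from the harness:

    from __future__ import annotations

    from collections.abc import Sequence


    def first_or_none(xs: Sequence[str]) -> str | None:
        # Without the __future__ import, evaluating `str | None` in this annotation
        # raises TypeError on Python 3.9; as a deferred (string) annotation it is fine.
        return xs[0] if xs else None


    print(first_or_none(["a", "b"]))  # a
    print(first_or_none([]))          # None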
@@ -31,7 +33,7 @@ def nanmean(arr: list[float]) -> float:


 @register_aggregation("mean")
-def mean(arr: list[float]) -> float:
+def mean(arr: Sequence[float]) -> float:
     return sum(arr) / len(arr)
@@ -70,7 +72,7 @@ def f1_score(items):


 @register_aggregation("matthews_corrcoef")
-def matthews_corrcoef(items):
+def matthews_corrcoef(items: Iterable[tuple[int, int] | tuple[str, str]]) -> float:
     from sklearn.metrics import matthews_corrcoef

     unzipped_list = list(zip(*items))
@@ -80,7 +82,7 @@ def matthews_corrcoef(items):


 @register_aggregation("bleu")
-def bleu(items):
+def bleu(items: Iterable[tuple[str, str]]):
     """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
     for evaluating a generated sentence to a reference sentence. It counts matching
     n-grams in the candidate translation to n-grams in the reference text, where
@@ -117,7 +119,7 @@ def chrf(items):


 @register_aggregation("ter")
-def ter(items):
+def ter(items: Iterable[tuple[str, str]]):
     """Translation Error Rate is an error metric for machine translation that
     measures the number of edits required to change a system output into one
     of the references
@@ -135,7 +137,9 @@ def ter(items):


 @register_aggregation("brier_score")
-def brier_score(items):  # This is a passthrough function
+def brier_score(
+    items: Iterable[tuple[str, float]],
+):  # This is a passthrough function
     gold, predictions = list(zip(*items))
     bs, num_class = np.array(predictions).shape
@@ -203,8 +207,8 @@ def acc_mutual_info_fn(items):  # This is a passthrough function
 # See the License for the specific language governing permissions and
 # limitations under the License.
 def exact_match_hf_evaluate(
-    predictions,
-    references,
+    predictions: Iterable[str],
+    references: Iterable[str],
     regexes_to_ignore=None,
     ignore_case=False,
     ignore_punctuation=False,
@@ -266,7 +270,7 @@ def perplexity_fn(items):  # This is a passthrough function
     output_type="loglikelihood_rolling",
     aggregation="weighted_perplexity",
 )
-def word_perplexity_fn(items):  # This is a passthrough function
+def word_perplexity_fn(items: T) -> T:  # This is a passthrough function
     return items
@@ -276,7 +280,7 @@ def word_perplexity_fn(items):  # This is a passthrough function
     output_type="loglikelihood_rolling",
     aggregation="weighted_perplexity",
 )
-def byte_perplexity_fn(items):  # This is a passthrough function
+def byte_perplexity_fn(items: T) -> T:  # This is a passthrough function
     return items
@@ -286,7 +290,7 @@ def byte_perplexity_fn(items):  # This is a passthrough function
     output_type="loglikelihood_rolling",
     aggregation="bits_per_byte",
 )
-def bits_per_byte_fn(items):  # This is a passthrough function
+def bits_per_byte_fn(items: T) -> T:  # This is a passthrough function
     return items
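The three perplexity metrics above are passthroughs: the registered aggregation does the real work, so each function just returns its input, and the new `T`-typed signature records that the output type equals the input type. A hedged sketch of the pattern; it assumes a module-level `T = TypeVar("T")`, which the new `typing` import suggests but this page does not show:

    from typing import TypeVar

    T = TypeVar("T")


    def passthrough_metric(items: T) -> T:
        # Identity function: the downstream aggregation (e.g. "weighted_perplexity")
        # consumes the untouched (loglikelihood, length) pairs later.
        return items


    pairs = [(-12.3, 5), (-8.1, 4)]
    assert passthrough_metric(pairs) is pairs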
@@ -295,7 +299,7 @@ def pop_stddev(arr):
     return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))


-def sample_stddev(arr: Sequence[T]) -> float:
+def sample_stddev(arr: Sequence[float]) -> float:
     mu = mean(arr)
     return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
@@ -416,7 +420,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
     return max(scores_for_ground_truths)


-def weighted_mean(items: List[tuple[float, float]]) -> float:
+def weighted_mean(items: list[tuple[float, float]]) -> float:
     a, b = zip(*items)
     return sum(a) / sum(b)
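For context, `weighted_mean` (unchanged apart from the `list[...]` spelling) sums the values and the weights separately, which is how a weighted-perplexity style aggregation typically combines per-document loglikelihoods with token or byte counts. A quick usage check with made-up numbers:

    def weighted_mean(items: list[tuple[float, float]]) -> float:
        a, b = zip(*items)
        return sum(a) / sum(b)


    # (value, weight) pairs: (2 + 6) / (1 + 3) = 2.0
    print(weighted_mean([(2.0, 1.0), (6.0, 3.0)]))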
@@ -427,15 +431,15 @@ def is_non_str_iterable(obj):

 def _sacreformat(refs, preds):
     """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
-    # Sacrebleu expects (List[str], List[List[str])
+    # Sacrebleu expects (list[str], list[list[str])
     # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
     # Note [ref1_stream] is the first reference for each pred.
     # So lists are size N and (M, N) for N preds and M possible refs for each pred
     # This is a different order of dimensions that I would expect

-    # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
-    # Must become List[List[str]] with the inner list corresponding to preds
+    # We expect refs to be list[str] or list[list[str]], the outer list corresponding to preds
+    # Must become list[list[str]] with the inner list corresponding to preds
     if not is_non_str_iterable(refs):
         refs = list(refs)
     if not is_non_str_iterable(refs[0]):
@@ -443,7 +447,7 @@ def _sacreformat(refs, preds):
         refs = list(zip(*refs))
     # Note the number of refs in each ref list much match the number of preds
-    # We expect preds to be List[str] or List[List[str]]. Must become List[str]
+    # We expect preds to be list[str] or list[list[str]]. Must become list[str]
     if not is_non_str_iterable(preds):
         preds = list(preds)
     if is_non_str_iterable(preds[0]):
@@ -456,7 +460,7 @@ def _sacreformat(refs, preds):

 # stderr stuff


-class _bootstrap_internal:
+class _bootstrap_internal(Generic[T]):
     """
     Pool worker: `(i, xs)` → `n` bootstrap replicates
     of `f(xs)` using a RNG seeded with `i`.
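Making `_bootstrap_internal` generic simply parameterizes the worker over its sample type. A simplified sketch of the pattern described by the docstring (seeded RNG, `n` replicates of `f(xs)`); this is not the harness's actual worker, which also farms chunks out across a multiprocessing pool:

    import random
    from collections.abc import Callable, Sequence
    from typing import Generic, TypeVar

    T = TypeVar("T")


    class BootstrapWorker(Generic[T]):
        def __init__(self, f: Callable[[Sequence[T]], float], n: int) -> None:
            self.f = f
            self.n = n

        def __call__(self, arg: tuple[int, Sequence[T]]) -> list[float]:
            i, xs = arg
            rng = random.Random(i)  # replicates are reproducible per chunk index
            return [self.f(rng.choices(xs, k=len(xs))) for _ in range(self.n)]


    worker: BootstrapWorker[float] = BootstrapWorker(lambda xs: sum(xs) / len(xs), n=3)
    print(worker((0, [1.0, 2.0, 3.0])))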
@@ -539,7 +543,7 @@ def bootstrap_stderr(

 def stderr_for_metric(
     metric: Callable[[Sequence[T]], float], bootstrap_iters: int
-) -> Optional[Callable[[Sequence[T]], float]]:
+) -> Callable[[Sequence[T]], float] | None:
     """
     Return a function that estimates the standard error of `metric(xs)`.
@@ -569,10 +573,10 @@ def stderr_for_metric(
     stderr = {mean: mean_stderr, acc_all: acc_all_stderr}

-    return stderr.get(metric, None)
+    return stderr.get(metric)


-def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
+def pooled_sample_stderr(stderrs: list[float], sizes: list[int]):
     # Used to aggregate bootstrapped stderrs across subtasks in a group,
     # when we are weighting by the size of each subtask.
     #
@@ -590,7 +594,7 @@ def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
     return np.sqrt(pooled_sample_var / sum(sizes))


-def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
+def combined_sample_stderr(stderrs: list[float], sizes: list[int], metrics=None):
     assert metrics is not None, (
         "Need to pass a list of each subtask's metric for this stderr aggregation"
     )
@@ -622,7 +626,9 @@ def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None)
     return np.sqrt(variance)


-def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
+def aggregate_subtask_metrics(
+    metrics: list[float], sizes: list[float], weight_by_size: bool = True
+):
     # A helper function that is used to aggregate
     # subtask scores cross-task.
     # TODO: does not hold for non-mean aggregations
@@ -631,4 +637,4 @@ def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
     assert len(metrics) == len(sizes)
-    return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes)
+    return sum(metric * size for metric, size in zip(metrics, sizes)) / sum(sizes)
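The last hunk is the C419-style fix: `sum()` accepts a generator directly, so the temporary list comprehension goes away and behaviour is unchanged. A standalone sketch of the size-weighted aggregation that line computes (`weighted_subtask_mean` is an illustrative name, not the harness API):

    def weighted_subtask_mean(metrics: list[float], sizes: list[float]) -> float:
        # Generator expression instead of sum([ ... ]).
        return sum(metric * size for metric, size in zip(metrics, sizes)) / sum(sizes)


    # Two subtasks: accuracy 0.8 on 100 docs and 0.5 on 300 docs -> 0.575 overall.
    print(weighted_subtask_mean([0.8, 0.5], [100, 300]))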
lm_eval/api/task.py

@@ -1053,7 +1053,9 @@ class ConfigurableTask(Task):
             print(type(doc_to_text))
             raise TypeError

-    def doc_to_target(self, doc: dict, doc_to_target=None) -> Union[int, str, list]:
+    def doc_to_target(
+        self, doc: dict, doc_to_target=None
+    ) -> Union[int, str, list[int]]:
         # if self.prompt is not None:
         #     doc_to_target = self.prompt
         if doc_to_target is not None:
@@ -1100,7 +1102,9 @@ class ConfigurableTask(Task):
             raise TypeError

     def doc_to_choice(
-        self, doc: dict, doc_to_choice: Union[str, list, dict, None] = None
+        self,
+        doc: dict,
+        doc_to_choice: Union[str, list, dict, Callable[..., list[str]], None] = None,
     ) -> List[str]:
         # if self.prompt is not None:
         #     doc_to_choice = self.prompt
@@ -1123,8 +1127,8 @@ class ConfigurableTask(Task):
             return list(doc_to_choice.values())
         elif callable(doc_to_choice):
             return doc_to_choice(doc)
-        elif hasattr(doc_to_choice, "get_answer_choices_list"):
-            return doc_to_choice.get_answer_choices_list(doc)
+        # elif hasattr(doc_to_choice, "get_answer_choices_list"):
+        #     return doc_to_choice.get_answer_choices_list(doc)
         else:
             raise TypeError
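The commented-out branch retires the promptsource-style `get_answer_choices_list` objects from the dispatch, while the widened annotation admits a plain callable that maps a document to its answer strings. A minimal dispatch sketch mirroring the branches visible in the hunk (string-template handling omitted; names are illustrative only):

    from collections.abc import Callable
    from typing import Union


    def resolve_choices(
        doc: dict,
        doc_to_choice: Union[list, dict, Callable[..., list[str]]],
    ) -> list[str]:
        if isinstance(doc_to_choice, list):
            return doc_to_choice
        elif isinstance(doc_to_choice, dict):
            return list(doc_to_choice.values())
        elif callable(doc_to_choice):
            return doc_to_choice(doc)
        else:
            raise TypeError


    print(resolve_choices({}, ["yes", "no"]))
    print(resolve_choices({"options": ["A", "B"]}, lambda d: d["options"]))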
@@ -1333,6 +1337,8 @@ class ConfigurableTask(Task):
                 raise ValueError
             # and this stores our "regular" conditional loglikelihoods
             lls = lls[: len(choices)]
+        else:
+            lls_unconditional = None

         pred = np.argmax(lls)
         pred_norm = np.argmax(lls / completion_len)
@@ -1390,6 +1396,9 @@ class ConfigurableTask(Task):
         }

         if "acc_mutual_info" in use_metric:
+            assert lls_unconditional is not None, (
+                "lls_unconditional should not be None if acc_mutual_info is in use_metric"
+            )
             lls_mutual_info = [
                 ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
             ]
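The new assert guards the mutual-information path: the unconditional loglikelihoods only exist when `acc_mutual_info` was requested. A tiny numeric illustration of the scoring in this hunk (the numbers are made up):

    import numpy as np

    lls = [-4.2, -3.1, -5.0]                # log P(choice | context)
    lls_unconditional = [-2.0, -0.5, -4.5]  # log P(choice) with no context
    lls_mutual_info = [ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)]

    # Plain accuracy picks choice 1; mutual information picks choice 2.
    print(int(np.argmax(lls)), int(np.argmax(lls_mutual_info)))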
lm_eval/config/utils.py

@@ -3,8 +3,8 @@ from typing import Any, Callable, Union

 def serialize_callable(
-    value: Union[Callable, str], keep_callable=False
-) -> Union[Callable, str]:
+    value: Union[Callable[..., Any], str], keep_callable=False
+) -> Union[Callable[..., Any], str]:
     """Serializes a given function or string.

     If 'keep_callable' is True, the original callable is returned.
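Only the annotations change here (`Callable` becomes the explicit `Callable[..., Any]`). For readers unfamiliar with the helper, a hedged sketch of what a serializer matching the docstring could look like; the real implementation may differ, and the `getsource` fallback is my assumption:

    from inspect import getsource
    from typing import Any, Callable, Union


    def serialize_callable_sketch(
        value: Union[Callable[..., Any], str], keep_callable: bool = False
    ) -> Union[Callable[..., Any], str]:
        """Return the callable itself when keep_callable is True, otherwise its source text."""
        if keep_callable or isinstance(value, str):
            return value
        try:
            return getsource(value)
        except (OSError, TypeError):
            return str(value)


    print(serialize_callable_sketch(len))         # builtins have no source: falls back to str(len)
    print(serialize_callable_sketch("utils.fn"))  # strings pass through unchanged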
lm_eval/evaluator.py

+from __future__ import annotations
+
 import itertools
 import json
 import logging

@@ -5,7 +7,7 @@ import os
 import random
 import time
 from collections import defaultdict
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Union

 import numpy as np
 import torch

@@ -49,7 +51,7 @@ eval_logger = logging.getLogger(__name__)
 @positional_deprecated
 def simple_evaluate(
     model,
-    model_args: Optional[Union[str, dict]] = None,
+    model_args: Optional[Union[str, dict[str, Any]]] = None,
     tasks: Optional[List[Union[str, dict, object]]] = None,
     num_fewshot: Optional[int] = None,
     batch_size: Optional[Union[int, str]] = None,

@@ -420,7 +422,7 @@ def simple_evaluate(
 def evaluate(
     lm: "LM",
     task_dict,
-    limit: Optional[int] = None,
+    limit: int | float | None = None,
     samples: Optional[dict] = None,
     cache_requests: bool = False,
     rewrite_requests_cache: bool = False,
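`limit` is now typed `int | float | None`, a spelling made possible by the `from __future__ import annotations` added at the top of the file. A hedged sketch of how such a parameter is commonly resolved, assuming the usual convention that an int is an absolute document count and a float a fraction of the dataset; this is an illustration, not the evaluator's actual logic:

    from __future__ import annotations


    def resolve_limit(limit: int | float | None, n_docs: int) -> int:
        if limit is None:
            return n_docs
        if isinstance(limit, float):
            return max(1, int(n_docs * limit))
        return min(int(limit), n_docs)


    print(resolve_limit(None, 1000), resolve_limit(0.1, 1000), resolve_limit(25, 1000))
    # 1000 100 25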
pyproject.toml

@@ -107,16 +107,19 @@ plugins.md028.enabled = false # no-blanks-blockquote
 plugins.md029.allow_extended_start_values = true # ol-prefix
 plugins.md034.enabled = false # no-bare-urls

-[tool.ruff.lint]
-extend-select = ["I", "W605"]
+[tool.ruff]
+target-version = "py39"
+lint.extend-select = ["I", "UP", "E", "C419", "F", "B", "SIM"]
+lint.ignore = ["E402", "E731", "E501", "E111", "E114", "E117"]

 [tool.ruff.lint.isort]
 combine-as-imports = true
 lines-after-imports = 2
 known-first-party = ["lm_eval"]

 [tool.ruff.lint.extend-per-file-ignores]
 "__init__.py" = ["F401","F402","F403"]
 "utils.py" = ["F401"]

 [dependency-groups]
 dev = [
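This ruff block drives most of the Python changes in the commit: `target-version = "py39"` plus the `UP` (pyupgrade) rules produce the `List` → `list` and `Optional[X]` → `X | None` rewrites, `C419` drops temporary list comprehensions inside calls like `sum()`, and `E`, `F`, `B`, and `SIM` add the pycodestyle, pyflakes, bugbear, and simplify rule families. A compact before/after sketch in the style of those rewrites (my own example, not code from the repository):

    from __future__ import annotations

    from collections.abc import Sequence


    # Before (flagged under the new config):
    #     def mean(arr: List[float]) -> Optional[float]:
    #         return sum([x for x in arr]) / len(arr) if arr else None
    # After:
    def mean(arr: Sequence[float]) -> float | None:
        return sum(arr) / len(arr) if arr else None


    print(mean([1.0, 2.0, 3.0]))  # 2.0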